diff --git a/src/storage/expressions.cpp b/src/storage/expressions.cpp index 354a2f8..59bfba5 100644 --- a/src/storage/expressions.cpp +++ b/src/storage/expressions.cpp @@ -39,6 +39,8 @@ void define_expressions(py::module& m) { .def("has_numerical_type", &storm::expressions::Variable::hasNumericalType, "Check if the variable is of numerical type") .def("has_bitvector_type", &storm::expressions::Variable::hasBitVectorType, "Check if the variable is of bitvector type") .def("get_expression", &storm::expressions::Variable::getExpression, "Get expression from variable") + .def("__eq__", &storm::expressions::Variable::operator==) + .def("__hash__", &storm::expressions::Variable::getIndex) ; @@ -130,4 +132,4 @@ void define_expressions(py::module& m) { .def_property_readonly("is_rational", &storm::expressions::Type::isRationalType) .def("__str__", &storm::expressions::Type::getStringRepresentation); -} \ No newline at end of file +} diff --git a/src/storage/prism.cpp b/src/storage/prism.cpp index e7d4ca9..e3bab6c 100644 --- a/src/storage/prism.cpp +++ b/src/storage/prism.cpp @@ -1,12 +1,25 @@ #include "prism.h" #include +#include #include "src/helpers.h" #include #include #include +#include +#include +#include +#include "storm/exceptions/NotSupportedException.h" +#include +#include "storm/exceptions/InvalidTypeException.h" +#include "storm/exceptions/InvalidStateException.h" +#include "storm/exceptions/InvalidAccessException.h" + using namespace storm::prism; +template +void define_stateGeneration(py::module& m); + void define_prism(py::module& m) { py::class_> program(m, "PrismProgram", "A Prism Program"); program.def_property_readonly("constants", &Program::getConstants, "Get Program Constants") @@ -38,22 +51,28 @@ void define_prism(py::module& m) { .def_property_readonly("name", &Module::getName, "Name of the module") .def_property_readonly("integer_variables", &Module::getIntegerVariables, "All integer Variables of this module") .def_property_readonly("boolean_variables", &Module::getBooleanVariables, "All boolean Variables of this module") + .def("__str__", &streamToString) ; py::class_ command(m, "PrismCommand", "A command in a Prism program"); command.def_property_readonly("global_index", &Command::getGlobalIndex, "Get global index") .def_property_readonly("guard_expression", &Command::getGuardExpression, "Get guard expression") - .def_property_readonly("updates", &Command::getUpdates, "Updates in the command"); + .def_property_readonly("updates", &Command::getUpdates, "Updates in the command") + .def("__str__", &streamToString) + ; py::class_ update(m, "PrismUpdate", "An update in a Prism command"); update.def_property_readonly("assignments", &Update::getAssignments, "Assignments in the update") .def_property_readonly("probability_expression", &Update::getLikelihoodExpression, "The probability expression for this update") - ;//Added by Kevin + .def("__str__", &streamToString) + ; py::class_ assignment(m, "PrismAssignment", "An assignment in prism"); assignment.def_property_readonly("variable", &Assignment::getVariable, "Variable that is updated") - .def_property_readonly("expression", &Assignment::getExpression, "Expression for the update"); - + .def_property_readonly("expression", &Assignment::getExpression, "Expression for the update") + .def("__str__", &streamToString) + ; + // PrismType @@ -85,9 +104,183 @@ void define_prism(py::module& m) { py::class_> integer_variable(m, "Prism_Integer_Variable", variable, "A program integer variable in a Prism program"); integer_variable.def_property_readonly("lower_bound_expression", &IntegerVariable::getLowerBoundExpression, "The the lower bound expression of this integer variable") .def_property_readonly("upper_bound_expression", &IntegerVariable::getUpperBoundExpression, "The the upper bound expression of this integer variable") + .def("__str__", &streamToString) ; py::class_> boolean_variable(m, "Prism_Boolean_Variable", variable, "A program boolean variable in a Prism program"); + boolean_variable.def("__str__", &streamToString); + + define_stateGeneration(m); +} + +class ValuationMapping { + public: + std::map booleanValues; + std::map integerValues; + std::map rationalValues; + + ValuationMapping(storm::prism::Program const& program, + storm::expressions::SimpleValuation valuation) { + auto const& variables = program.getManager().getVariables(); + for (auto const& variable : variables) { + if (variable.hasBooleanType()) { + booleanValues[variable] = valuation.getBooleanValue(variable); + } else if (variable.hasIntegerType()) { + integerValues[variable] = valuation.getIntegerValue(variable); + } else if (variable.hasRationalType()) { + rationalValues[variable] = valuation.getRationalValue(variable); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, + "Unexpected variable type."); + } + } + } + + std::string toString() const { + std::vector strs; + for (auto const& value : booleanValues) { + std::stringstream sstr; + sstr << value.first.getName() + "="; + sstr << value.second; + strs.push_back(sstr.str()); + } + for (auto const& value : integerValues) { + std::stringstream sstr; + sstr << value.first.getName() + "="; + sstr << value.second; + strs.push_back(sstr.str()); + } + for (auto const& value : rationalValues) { + std::stringstream sstr; + sstr << value.first.getName() + "="; + sstr << value.second; + strs.push_back(sstr.str()); + } + return "[" + boost::join(strs, ",") + "]"; + } +}; + +template +class StateGenerator { + typedef std::unordered_map IdToStateMap; + typedef std::vector> distribution_type; + typedef distribution_type choice_type; // currently we just collapse choices into distributions + typedef std::vector choice_list_type; + + storm::prism::Program const& program; + storm::generator::PrismNextStateGenerator generator; + std::function stateToIdCallback; + // this needs to be below the generator attribute, + // because its initialization depends on the generator being initialized. + // #justcppthings + storm::storage::sparse::StateStorage stateStorage; + bool hasComputedInitialStates = false; + IdToStateMap stateMap; + boost::optional currentStateIndex; + + public: + StateGenerator(storm::prism::Program const& program_) : program(program_), generator(program_), stateStorage(generator.getStateSize()) { + stateToIdCallback = [this] (storm::generator::CompressedState const& state) -> StateType { + StateType newIndex = stateStorage.getNumberOfStates(); + std::pair indexBucketPair = stateStorage.stateToId.findOrAddAndGetBucket(state, newIndex); + StateType index = indexBucketPair.first; + stateMap[index] = state; + return index; + }; + } + + StateType loadInitialState() { + if (!hasComputedInitialStates) { + stateStorage.initialStateIndices = generator.getInitialStates(stateToIdCallback); + hasComputedInitialStates = true; + } + STORM_LOG_THROW(stateStorage.initialStateIndices.size() == 1, storm::exceptions::NotSupportedException, "Currently only models with one initial state are supported."); + StateType initialStateIndex = stateStorage.initialStateIndices.front(); + load(initialStateIndex); + return initialStateIndex; + } + + void load(StateType stateIndex) { + if (currentStateIndex && *currentStateIndex == stateIndex) { + return; + } + auto search = stateMap.find(stateIndex); + if (search == stateMap.end()) { + STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, + "state id not found"); + } + generator.load(search->second); + currentStateIndex = stateIndex; + } + + ValuationMapping currentStateToValuation() { + if (!currentStateIndex) { + STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, + "Initial state not initialized"); + } + auto valuation = generator.toValuation(stateMap[*currentStateIndex]); + return ValuationMapping(program, valuation); + } + + bool satisfies(storm::expressions::Expression const& expression) { + return generator.satisfies(expression); + } + + choice_list_type expand() { + if (!hasComputedInitialStates) { + STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, + "Initial state not initialized"); + } + choice_list_type choices_result; + auto behavior = generator.expand(stateToIdCallback); + for (auto choice : behavior.getChoices()) { + choices_result.push_back(choice_type(choice.begin(), choice.end())); + } + + return choices_result; + } + +}; + +template +void define_stateGeneration(py::module& m) { + py::class_> valuation_mapping(m, "ValuationMapping", "A valuation mapping for a state consists of a mapping from variable to value for each of the three types."); + valuation_mapping + .def_readonly("boolean_values", &ValuationMapping::booleanValues) + .def_readonly("integer_values", &ValuationMapping::integerValues) + .def_readonly("rational_values", &ValuationMapping::rationalValues) + .def("__str__", &ValuationMapping::toString); + + py::class_, std::shared_ptr>> state_generator(m, "StateGenerator", R"doc( + Interactively explore states using Storm's PrismNextStateGenerator. + + :ivar program: A PRISM program. + )doc"); + state_generator + .def(py::init()) + .def("load_initial_state", &StateGenerator::loadInitialState, R"doc( + Loads the (unique) initial state. + Multiple initial states are not supported. + + :rtype: the ID of the initial state. + )doc") + .def("load", &StateGenerator::load, R"doc( + :param state_id: The ID of the state to load. + )doc") + .def("current_state_to_valuation", &StateGenerator::currentStateToValuation, R"doc( + Return a valuation for the currently loaded state. + + :rtype: stormpy.ValuationMapping + )doc") + .def("current_state_satisfies", &StateGenerator::satisfies, R"doc( + Check if the currently loaded state satisfies the given expression. + :param stormpy.Expression expression: The expression to check against. + :rtype: bool + )doc") + .def("expand", &StateGenerator::expand, R"doc( + Expand the currently loaded state and return its successors. -} \ No newline at end of file + :rtype: [[(state_id, probability)]] + )doc"); +} diff --git a/src/storage/valuation.cpp b/src/storage/valuation.cpp index 36c9699..212a01d 100644 --- a/src/storage/valuation.cpp +++ b/src/storage/valuation.cpp @@ -18,4 +18,5 @@ void define_simplevaluation(py::module& m) { py::class_> simplevaluation(m, "SimpleValuation", "Valuations for storm variables"); simplevaluation.def("get_boolean_value", &storm::expressions::SimpleValuation::getBooleanValue); simplevaluation.def("get_integer_value", &storm::expressions::SimpleValuation::getIntegerValue); -} \ No newline at end of file + simplevaluation.def("__str__", &storm::expressions::SimpleValuation::toString); +} diff --git a/tests/storage/test_state_generation.py b/tests/storage/test_state_generation.py new file mode 100644 index 0000000..f61b11c --- /dev/null +++ b/tests/storage/test_state_generation.py @@ -0,0 +1,75 @@ +import stormpy +from stormpy.examples.files import prism_dtmc_die + +class _DfsQueue: + def __init__(self): + self.queue = [] + self.visited = set() + + def push(self, state_id): + if state_id not in self.visited: + self.queue.append(state_id) + self.visited.add(state_id) + + def pop(self): + if len(self.queue) > 0: + return self.queue.pop() + return None + + +def _dfs_explore(program, callback): + generator = stormpy.StateGenerator(program) + + queue = _DfsQueue() + current_state_id = generator.load_initial_state() + queue.visited.add(current_state_id) + + while True: + callback(current_state_id, generator) + + successors = generator.expand() + assert len(successors) <= 1 + for choice in successors: + for state_id, _prob in choice: + queue.push(state_id) + current_state_id = queue.pop() + if current_state_id is None: + break + generator.load(current_state_id) + + +def _load_program(filename): + program = stormpy.parse_prism_program(filename) # pylint: disable=no-member + program = program.substitute_constants() + + expression_parser = stormpy.ExpressionParser(program.expression_manager) + expression_parser.set_identifier_mapping( + {var.name: var.get_expression() + for var in program.variables}) + return program, expression_parser + + +def _find_variable(program, name): + for var in program.variables: + if var.name is name: + return var + return None + +def test_knuth_yao_die(): + program, expression_parser = _load_program(prism_dtmc_die) + s_variable = _find_variable(program, "s") + upper_bound_invariant = expression_parser.parse("d <= 6") + + number_states = 0 + + def callback(_state_id, generator): + nonlocal number_states + number_states += 1 + + valuation = generator.current_state_to_valuation() + print(valuation) + assert valuation.integer_values[s_variable] <= 7 + assert generator.current_state_satisfies(upper_bound_invariant) + + _dfs_explore(program, callback) + assert number_states == 13