diff --git a/src/storm/generator/JaniNextStateGenerator.cpp b/src/storm/generator/JaniNextStateGenerator.cpp index 883abecad..f052133a9 100644 --- a/src/storm/generator/JaniNextStateGenerator.cpp +++ b/src/storm/generator/JaniNextStateGenerator.cpp @@ -171,59 +171,118 @@ namespace storm { template std::vector JaniNextStateGenerator::getInitialStates(StateToIdCallback const& stateToIdCallback) { - // Prepare an SMT solver to enumerate all initial states. - storm::utility::solver::SmtSolverFactory factory; - std::unique_ptr solver = factory.create(model.getExpressionManager()); - - std::vector rangeExpressions = model.getAllRangeExpressions(this->parallelAutomata); - for (auto const& expression : rangeExpressions) { - solver->add(expression); - } - solver->add(model.getInitialStatesExpression(this->parallelAutomata)); - - // Proceed as long as the solver can still enumerate initial states. std::vector initialStateIndices; - while (solver->check() == storm::solver::SmtSolver::CheckResult::Sat) { - // Create fresh state. + + if (this->model.hasNonTrivialInitialStatesRestriction()) { + // Prepare an SMT solver to enumerate all initial states. + storm::utility::solver::SmtSolverFactory factory; + std::unique_ptr solver = factory.create(model.getExpressionManager()); + + std::vector rangeExpressions = model.getAllRangeExpressions(this->parallelAutomata); + for (auto const& expression : rangeExpressions) { + solver->add(expression); + } + solver->add(model.getInitialStatesExpression(this->parallelAutomata)); + + // Proceed as long as the solver can still enumerate initial states. + while (solver->check() == storm::solver::SmtSolver::CheckResult::Sat) { + // Create fresh state. + CompressedState initialState(this->variableInformation.getTotalBitOffset(true)); + + // Read variable assignment from the solution of the solver. Also, create an expression we can use to + // prevent the variable assignment from being enumerated again. + storm::expressions::Expression blockingExpression; + std::shared_ptr model = solver->getModel(); + for (auto const& booleanVariable : this->variableInformation.booleanVariables) { + bool variableValue = model->getBooleanValue(booleanVariable.variable); + storm::expressions::Expression localBlockingExpression = variableValue ? !booleanVariable.variable : booleanVariable.variable; + blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; + initialState.set(booleanVariable.bitOffset, variableValue); + } + for (auto const& integerVariable : this->variableInformation.integerVariables) { + int_fast64_t variableValue = model->getIntegerValue(integerVariable.variable); + storm::expressions::Expression localBlockingExpression = integerVariable.variable != model->getManager().integer(variableValue); + blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; + initialState.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, static_cast(variableValue - integerVariable.lowerBound)); + } + + // Gather iterators to the initial locations of all the automata. + std::vector::const_iterator> initialLocationsIts; + std::vector::const_iterator> initialLocationsItes; + for (auto const& automatonRef : this->parallelAutomata) { + auto const& automaton = automatonRef.get(); + initialLocationsIts.push_back(automaton.getInitialLocationIndices().cbegin()); + initialLocationsItes.push_back(automaton.getInitialLocationIndices().cend()); + } + storm::utility::combinatorics::forEach(initialLocationsIts, initialLocationsItes, [this,&initialState] (uint64_t index, uint64_t value) { setLocation(initialState, this->variableInformation.locationVariables[index], value); }, [&stateToIdCallback,&initialStateIndices,&initialState] () { + // Register initial state. + StateType id = stateToIdCallback(initialState); + initialStateIndices.push_back(id); + return true; + }); + + // Block the current initial state to search for the next one. + if (!blockingExpression.isInitialized()) { + break; + } + solver->add(blockingExpression); + } + + STORM_LOG_DEBUG("Enumerated " << initialStateIndices.size() << " initial states using SMT solving."); + } else { CompressedState initialState(this->variableInformation.getTotalBitOffset(true)); - // Read variable assignment from the solution of the solver. Also, create an expression we can use to - // prevent the variable assignment from being enumerated again. - storm::expressions::Expression blockingExpression; - std::shared_ptr model = solver->getModel(); - for (auto const& booleanVariable : this->variableInformation.booleanVariables) { - bool variableValue = model->getBooleanValue(booleanVariable.variable); - storm::expressions::Expression localBlockingExpression = variableValue ? !booleanVariable.variable : booleanVariable.variable; - blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; - initialState.set(booleanVariable.bitOffset, variableValue); - } - for (auto const& integerVariable : this->variableInformation.integerVariables) { - int_fast64_t variableValue = model->getIntegerValue(integerVariable.variable); - storm::expressions::Expression localBlockingExpression = integerVariable.variable != model->getManager().integer(variableValue); - blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; - initialState.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, static_cast(variableValue - integerVariable.lowerBound)); - } - - // Gather iterators to the initial locations of all the automata. - std::vector::const_iterator> initialLocationsIts; - std::vector::const_iterator> initialLocationsItes; - for (auto const& automatonRef : this->parallelAutomata) { - auto const& automaton = automatonRef.get(); - initialLocationsIts.push_back(automaton.getInitialLocationIndices().cbegin()); - initialLocationsItes.push_back(automaton.getInitialLocationIndices().cend()); - } - storm::utility::combinatorics::forEach(initialLocationsIts, initialLocationsItes, [this,&initialState] (uint64_t index, uint64_t value) { setLocation(initialState, this->variableInformation.locationVariables[index], value); }, [&stateToIdCallback,&initialStateIndices,&initialState] () { - // Register initial state. - StateType id = stateToIdCallback(initialState); - initialStateIndices.push_back(id); - return true; - }); - - // Block the current initial state to search for the next one. - if (!blockingExpression.isInitialized()) { - break; + std::vector currentIntegerValues; + currentIntegerValues.reserve(this->variableInformation.integerVariables.size()); + for (auto const& variable : this->variableInformation.integerVariables) { + STORM_LOG_THROW(variable.lowerBound <= variable.upperBound, storm::exceptions::InvalidArgumentException, "Expecting variable with non-empty set of possible values."); + currentIntegerValues.emplace_back(0); + initialState.setFromInt(variable.bitOffset, variable.bitWidth, 0); } - solver->add(blockingExpression); + + initialStateIndices.emplace_back(stateToIdCallback(initialState)); + + bool done = false; + while (!done) { + bool changedBooleanVariable = false; + for (auto const& booleanVariable : this->variableInformation.booleanVariables) { + if (initialState.get(booleanVariable.bitOffset)) { + initialState.set(booleanVariable.bitOffset); + changedBooleanVariable = true; + break; + } else { + initialState.set(booleanVariable.bitOffset, false); + } + } + + bool changedIntegerVariable = false; + if (changedBooleanVariable) { + initialStateIndices.emplace_back(stateToIdCallback(initialState)); + } else { + for (uint64_t integerVariableIndex = 0; integerVariableIndex < this->variableInformation.integerVariables.size(); ++integerVariableIndex) { + auto const& integerVariable = this->variableInformation.integerVariables[integerVariableIndex]; + if (currentIntegerValues[integerVariableIndex] < integerVariable.upperBound - integerVariable.lowerBound) { + ++currentIntegerValues[integerVariableIndex]; + changedIntegerVariable = true; + } else { + currentIntegerValues[integerVariableIndex] = integerVariable.lowerBound; + } + initialState.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, currentIntegerValues[integerVariableIndex]); + + if (changedIntegerVariable) { + break; + } + } + } + + if (changedIntegerVariable) { + initialStateIndices.emplace_back(stateToIdCallback(initialState)); + } + + done = !changedBooleanVariable && !changedIntegerVariable; + } + + STORM_LOG_DEBUG("Enumerated " << initialStateIndices.size() << " initial states using brute force enumeration."); } return initialStateIndices; diff --git a/src/storm/storage/jani/Automaton.cpp b/src/storm/storage/jani/Automaton.cpp index a3a3be400..5e97ebed3 100644 --- a/src/storm/storage/jani/Automaton.cpp +++ b/src/storm/storage/jani/Automaton.cpp @@ -286,8 +286,6 @@ namespace storm { for (uint64_t locationIndex = edge.getSourceLocationIndex() + 1; locationIndex < locationToStartingIndex.size(); ++locationIndex) { ++locationToStartingIndex[locationIndex]; } - - } std::vector& Automaton::getEdges() { @@ -325,6 +323,10 @@ namespace storm { return initialStatesRestriction.isInitialized(); } + bool Automaton::hasNonTrivialInitialStatesRestriction() const { + return this->hasInitialStatesRestriction() && !this->getInitialStatesRestriction().isTrue(); + } + storm::expressions::Expression const& Automaton::getInitialStatesRestriction() const { return initialStatesRestriction; } diff --git a/src/storm/storage/jani/Automaton.h b/src/storm/storage/jani/Automaton.h index 375164c0c..2b5a833a6 100644 --- a/src/storm/storage/jani/Automaton.h +++ b/src/storm/storage/jani/Automaton.h @@ -241,6 +241,11 @@ namespace storm { */ bool hasInitialStatesRestriction() const; + /*! + * Retrieves whether there is a non-trivial initial states restriction. + */ + bool hasNonTrivialInitialStatesRestriction() const; + /*! * Gets the expression restricting the legal initial values of the automaton's variables. */ diff --git a/src/storm/storage/jani/Model.cpp b/src/storm/storage/jani/Model.cpp index 216c5ced9..32af102ec 100644 --- a/src/storm/storage/jani/Model.cpp +++ b/src/storm/storage/jani/Model.cpp @@ -939,6 +939,20 @@ namespace storm { return initialStatesRestriction; } + bool Model::hasNonTrivialInitialStatesRestriction() const { + if (this->hasInitialStatesRestriction() && !this->getInitialStatesRestriction().isTrue()) { + return true; + } else { + for (auto const& automaton : this->automata) { + if (automaton.hasInitialStatesRestriction() && !automaton.getInitialStatesRestriction().isTrue()) { + return true; + } + } + } + + return false; + } + storm::expressions::Expression Model::getInitialStatesExpression() const { std::vector> allAutomata; for (auto const& automaton : this->getAutomata()) { diff --git a/src/storm/storage/jani/Model.h b/src/storm/storage/jani/Model.h index a6663b41f..70f6d9637 100644 --- a/src/storm/storage/jani/Model.h +++ b/src/storm/storage/jani/Model.h @@ -357,6 +357,12 @@ namespace storm { */ bool hasInitialStatesRestriction() const; + /*! + * Retrieves whether there is a non-trivial initial states restriction in the model or any of the contained + * automata. + */ + bool hasNonTrivialInitialStatesRestriction() const; + /*! * Sets the expression restricting the legal initial values of the global variables. */