diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp index 9518b2d34..8ee493696 100644 --- a/src/storm/generator/PrismNextStateGenerator.cpp +++ b/src/storm/generator/PrismNextStateGenerator.cpp @@ -136,48 +136,108 @@ namespace storm { template std::vector PrismNextStateGenerator::getInitialStates(StateToIdCallback const& stateToIdCallback) { - // Prepare an SMT solver to enumerate all initial states. - storm::utility::solver::SmtSolverFactory factory; - std::unique_ptr solver = factory.create(program.getManager()); - - std::vector rangeExpressions = program.getAllRangeExpressions(); - for (auto const& expression : rangeExpressions) { - solver->add(expression); - } - solver->add(program.getInitialStatesExpression()); - - // Proceed ss long as the solver can still enumerate initial states. std::vector initialStateIndices; - while (solver->check() == storm::solver::SmtSolver::CheckResult::Sat) { - // Create fresh state. + + // If all states are initial, we can simplify the enumeration substantially. + if (program.hasInitialConstruct() && program.getInitialConstruct().getInitialStatesExpression().isTrue()) { CompressedState initialState(this->variableInformation.getTotalBitOffset(true)); - - // Read variable assignment from the solution of the solver. Also, create an expression we can use to - // prevent the variable assignment from being enumerated again. - storm::expressions::Expression blockingExpression; - std::shared_ptr model = solver->getModel(); - for (auto const& booleanVariable : this->variableInformation.booleanVariables) { - bool variableValue = model->getBooleanValue(booleanVariable.variable); - storm::expressions::Expression localBlockingExpression = variableValue ? !booleanVariable.variable : booleanVariable.variable; - blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; - initialState.set(booleanVariable.bitOffset, variableValue); + + std::vector currentIntegerValues; + currentIntegerValues.reserve(this->variableInformation.integerVariables.size()); + for (auto const& variable : this->variableInformation.integerVariables) { + STORM_LOG_THROW(variable.lowerBound <= variable.upperBound, storm::exceptions::InvalidArgumentException, "Expecting variable with non-empty set of possible values."); + currentIntegerValues.emplace_back(0); + initialState.setFromInt(variable.bitOffset, variable.bitWidth, 0); } - for (auto const& integerVariable : this->variableInformation.integerVariables) { - int_fast64_t variableValue = model->getIntegerValue(integerVariable.variable); - storm::expressions::Expression localBlockingExpression = integerVariable.variable != model->getManager().integer(variableValue); - blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; - initialState.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, static_cast(variableValue - integerVariable.lowerBound)); + + initialStateIndices.emplace_back(stateToIdCallback(initialState)); + + bool done = false; + while (!done) { + bool changedBooleanVariable = false; + for (auto const& booleanVariable : this->variableInformation.booleanVariables) { + if (initialState.get(booleanVariable.bitOffset)) { + initialState.set(booleanVariable.bitOffset); + changedBooleanVariable = true; + break; + } else { + initialState.set(booleanVariable.bitOffset, false); + } + } + + bool changedIntegerVariable = false; + if (changedBooleanVariable) { + initialStateIndices.emplace_back(stateToIdCallback(initialState)); + } else { + for (uint64_t integerVariableIndex = 0; integerVariableIndex < this->variableInformation.integerVariables.size(); ++integerVariableIndex) { + auto const& integerVariable = this->variableInformation.integerVariables[integerVariableIndex]; + if (currentIntegerValues[integerVariableIndex] < integerVariable.upperBound - integerVariable.lowerBound) { + ++currentIntegerValues[integerVariableIndex]; + changedIntegerVariable = true; + } else { + currentIntegerValues[integerVariableIndex] = integerVariable.lowerBound; + } + initialState.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, currentIntegerValues[integerVariableIndex]); + + if (changedIntegerVariable) { + break; + } + } + } + + if (changedIntegerVariable) { + initialStateIndices.emplace_back(stateToIdCallback(initialState)); + } + + done = !changedBooleanVariable && !changedIntegerVariable; } - // Register initial state and return it. - StateType id = stateToIdCallback(initialState); - initialStateIndices.push_back(id); + STORM_LOG_DEBUG("Enumerated " << initialStateIndices.size() << " initial states using brute force enumeration."); + } else { + // Prepare an SMT solver to enumerate all initial states. + storm::utility::solver::SmtSolverFactory factory; + std::unique_ptr solver = factory.create(program.getManager()); - // Block the current initial state to search for the next one. - if (!blockingExpression.isInitialized()) { - break; + std::vector rangeExpressions = program.getAllRangeExpressions(); + for (auto const& expression : rangeExpressions) { + solver->add(expression); } - solver->add(blockingExpression); + solver->add(program.getInitialStatesExpression()); + + // Proceed ss long as the solver can still enumerate initial states. + while (solver->check() == storm::solver::SmtSolver::CheckResult::Sat) { + // Create fresh state. + CompressedState initialState(this->variableInformation.getTotalBitOffset(true)); + + // Read variable assignment from the solution of the solver. Also, create an expression we can use to + // prevent the variable assignment from being enumerated again. + storm::expressions::Expression blockingExpression; + std::shared_ptr model = solver->getModel(); + for (auto const& booleanVariable : this->variableInformation.booleanVariables) { + bool variableValue = model->getBooleanValue(booleanVariable.variable); + storm::expressions::Expression localBlockingExpression = variableValue ? !booleanVariable.variable : booleanVariable.variable; + blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; + initialState.set(booleanVariable.bitOffset, variableValue); + } + for (auto const& integerVariable : this->variableInformation.integerVariables) { + int_fast64_t variableValue = model->getIntegerValue(integerVariable.variable); + storm::expressions::Expression localBlockingExpression = integerVariable.variable != model->getManager().integer(variableValue); + blockingExpression = blockingExpression.isInitialized() ? blockingExpression || localBlockingExpression : localBlockingExpression; + initialState.setFromInt(integerVariable.bitOffset, integerVariable.bitWidth, static_cast(variableValue - integerVariable.lowerBound)); + } + + // Register initial state and return it. + StateType id = stateToIdCallback(initialState); + initialStateIndices.push_back(id); + + // Block the current initial state to search for the next one. + if (!blockingExpression.isInitialized()) { + break; + } + solver->add(blockingExpression); + } + + STORM_LOG_DEBUG("Enumerated " << initialStateIndices.size() << " initial states using SMT solving."); } return initialStateIndices; @@ -454,67 +514,60 @@ namespace storm { return result; } - + + template + void PrismNextStateGenerator::generateSynchronizedDistribution(storm::storage::BitVector const& state, ValueType const& probability, uint64_t position, std::vector>::const_iterator> const& iteratorList, storm::builder::jit::Distribution& distribution, StateToIdCallback stateToIdCallback) { + + if (storm::utility::isZero(probability)) { + return; + } + + if (position >= iteratorList.size()) { + StateType id = stateToIdCallback(state); + distribution.add(id, probability); + } else { + storm::prism::Command const& command = *iteratorList[position]; + for (uint_fast64_t j = 0; j < command.getNumberOfUpdates(); ++j) { + storm::prism::Update const& update = command.getUpdate(j); + generateSynchronizedDistribution(applyUpdate(state, update), probability * this->evaluator->asRational(update.getLikelihoodExpression()), position + 1, iteratorList, distribution, stateToIdCallback); + } + } + } + template std::vector> PrismNextStateGenerator::getLabeledChoices(CompressedState const& state, StateToIdCallback stateToIdCallback) { std::vector> result; - storm::builder::jit::Distribution currentDistribution; - storm::builder::jit::Distribution nextDistribution; - for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) { boost::optional>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(actionIndex); - + // Only process this action label, if there is at least one feasible solution. if (optionalActiveCommandLists) { std::vector>> const& activeCommandList = optionalActiveCommandLists.get(); std::vector>::const_iterator> iteratorList(activeCommandList.size()); - + // Initialize the list of iterators. for (size_t i = 0; i < activeCommandList.size(); ++i) { iteratorList[i] = activeCommandList[i].cbegin(); } - + + storm::builder::jit::Distribution distribution; + // As long as there is one feasible combination of commands, keep on expanding it. bool done = false; while (!done) { - currentDistribution.clear(); - nextDistribution.clear(); + distribution.clear(); + generateSynchronizedDistribution(state, storm::utility::one(), 0, iteratorList, distribution, stateToIdCallback); + distribution.compress(); - currentDistribution.add(state, storm::utility::one()); - - for (uint_fast64_t i = 0; i < iteratorList.size(); ++i) { - storm::prism::Command const& command = *iteratorList[i]; - for (uint_fast64_t j = 0; j < command.getNumberOfUpdates(); ++j) { - storm::prism::Update const& update = command.getUpdate(j); - - for (auto const& stateProbability : currentDistribution) { - ValueType probability = stateProbability.getValue() * this->evaluator->asRational(update.getLikelihoodExpression()); - - if (!storm::utility::isZero(probability)) { - // Compute the new state under the current update and add it to the set of new target states. - CompressedState newTargetState = applyUpdate(stateProbability.getState(), update); - nextDistribution.add(newTargetState, probability); - } - } - } - - nextDistribution.compress(); - - // If there is one more command to come, shift the target states one time step back. - if (i < iteratorList.size() - 1) { - currentDistribution = std::move(nextDistribution); - } - } - // At this point, we applied all commands of the current command combination and newTargetStates // contains all target states and their respective probabilities. That means we are now ready to // add the choice to the list of transitions. result.push_back(Choice(actionIndex)); - + // Now create the actual distribution. Choice& choice = result.back(); - + // Remember the choice label and origins only if we were asked to. if (this->options.isBuildChoiceLabelsSet()) { choice.addLabel(program.getActionName(actionIndex)); @@ -526,22 +579,21 @@ namespace storm { } choice.addOriginData(boost::any(std::move(commandIndices))); } - + // Add the probabilities/rates to the newly created choice. ValueType probabilitySum = storm::utility::zero(); - for (auto const& stateProbability : nextDistribution) { - StateType actualIndex = stateToIdCallback(stateProbability.getState()); - choice.addProbability(actualIndex, stateProbability.getValue()); + for (auto const& stateProbability : distribution) { + choice.addProbability(stateProbability.getState(), stateProbability.getValue()); if (this->options.isExplorationChecksSet()) { probabilitySum += stateProbability.getValue(); } } - + if (this->options.isExplorationChecksSet()) { // Check that the resulting distribution is in fact a distribution. STORM_LOG_THROW(!program.isDiscreteTimeModel() || !this->comparator.isConstant(probabilitySum) || this->comparator.isOne(probabilitySum), storm::exceptions::WrongFormatException, "Sum of update probabilities do not some to one for some command (actually sum to " << probabilitySum << ")."); } - + // Create the state-action reward for the newly created choice. for (auto const& rewardModel : rewardModels) { ValueType stateActionRewardValue = storm::utility::zero(); @@ -554,7 +606,7 @@ namespace storm { } choice.addReward(stateActionRewardValue); } - + // Now, check whether there is one more command combination to consider. bool movedIterator = false; for (int_fast64_t j = iteratorList.size() - 1; !movedIterator && j >= 0; --j) { @@ -566,12 +618,12 @@ namespace storm { iteratorList[j] = activeCommandList[j].begin(); } } - + done = !movedIterator; } } } - + return result; } diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h index 5cc8ce9ba..7861dfc5d 100644 --- a/src/storm/generator/PrismNextStateGenerator.h +++ b/src/storm/generator/PrismNextStateGenerator.h @@ -8,6 +8,13 @@ #include "storm/storage/prism/Program.h" namespace storm { + namespace builder { + namespace jit { + template + class Distribution; + } + } + namespace generator { template @@ -85,6 +92,11 @@ namespace storm { */ std::vector> getLabeledChoices(CompressedState const& state, StateToIdCallback stateToIdCallback); + /*! + * A recursive helper function to generate a synchronziing distribution. + */ + void generateSynchronizedDistribution(storm::storage::BitVector const& state, ValueType const& probability, uint64_t position, std::vector>::const_iterator> const& iteratorList, storm::builder::jit::Distribution& distribution, StateToIdCallback stateToIdCallback); + // The program used for the generation of next states. storm::prism::Program program; diff --git a/src/storm/generator/VariableInformation.cpp b/src/storm/generator/VariableInformation.cpp index b093e0234..13d3caf71 100644 --- a/src/storm/generator/VariableInformation.cpp +++ b/src/storm/generator/VariableInformation.cpp @@ -30,15 +30,13 @@ namespace storm { } VariableInformation::VariableInformation(storm::prism::Program const& program, bool outOfBoundsState) : totalBitOffset(0) { - if(outOfBoundsState) { + if (outOfBoundsState) { outOfBoundsBit = 0; ++totalBitOffset; } else { outOfBoundsBit = boost::none; } - - for (auto const& booleanVariable : program.getGlobalBooleanVariables()) { booleanVariables.emplace_back(booleanVariable.getExpressionVariable(), totalBitOffset, true); ++totalBitOffset; @@ -77,15 +75,13 @@ namespace storm { STORM_LOG_THROW(!automaton.getVariables().containsNonTransientUnboundedIntegerVariables(), storm::exceptions::InvalidArgumentException, "Cannot build model from JANI model that contains non-transient unbounded integer variables in automaton '" << automaton.getName() << "'."); STORM_LOG_THROW(!automaton.getVariables().containsNonTransientRealVariables(), storm::exceptions::InvalidArgumentException, "Cannot build model from JANI model that contains non-transient real variables in automaton '" << automaton.getName() << "'."); } - if(outOfBoundsState) { + if (outOfBoundsState) { outOfBoundsBit = 0; ++totalBitOffset; } else { outOfBoundsBit = boost::none; } - - for (auto const& variable : model.getGlobalVariables().getBooleanVariables()) { if (!variable.isTransient()) { booleanVariables.emplace_back(variable.getExpressionVariable(), totalBitOffset, true); diff --git a/src/storm/storage/expressions/Variable.cpp b/src/storm/storage/expressions/Variable.cpp index 675f1caef..0cccde8e0 100644 --- a/src/storm/storage/expressions/Variable.cpp +++ b/src/storm/storage/expressions/Variable.cpp @@ -16,7 +16,11 @@ namespace storm { } bool Variable::operator==(Variable const& other) const { +#ifndef NDEBUG return &this->getManager() == &other.getManager() && index == other.index; +#else + return index == other.index; +#endif } bool Variable::operator!=(Variable const& other) const {