diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp index 30e55f5bd..a0fb216e5 100644 --- a/src/storm/generator/PrismNextStateGenerator.cpp +++ b/src/storm/generator/PrismNextStateGenerator.cpp @@ -288,22 +288,18 @@ namespace storm { result.setExpanded(); std::vector> allChoices; - std::vector> allLabeledChoices; if (this->getOptions().isApplyMaximalProgressAssumptionSet()) { // First explore only edges without a rate allChoices = getUnlabeledChoices(*this->state, stateToIdCallback, CommandFilter::Probabilistic); - allLabeledChoices = getLabeledChoices(*this->state, stateToIdCallback, CommandFilter::Probabilistic); - if (allChoices.empty() && allLabeledChoices.empty()) { + addLabeledChoices(allChoices, *this->state, stateToIdCallback, CommandFilter::Probabilistic); + if (allChoices.empty()) { // Expand the Markovian edges if there are no probabilistic ones. allChoices = getUnlabeledChoices(*this->state, stateToIdCallback, CommandFilter::Markovian); - allLabeledChoices = getLabeledChoices(*this->state, stateToIdCallback, CommandFilter::Markovian); + addLabeledChoices(allChoices, *this->state, stateToIdCallback, CommandFilter::Markovian); } } else { allChoices = getUnlabeledChoices(*this->state, stateToIdCallback); - allLabeledChoices = getLabeledChoices(*this->state, stateToIdCallback); - } - for (auto& choice : allLabeledChoices) { - allChoices.push_back(std::move(choice)); + addLabeledChoices(allChoices, *this->state, stateToIdCallback); } std::size_t totalNumberOfChoices = allChoices.size(); @@ -428,11 +424,25 @@ namespace storm { return newState; } + struct ActiveCommandData { + ActiveCommandData(storm::prism::Module const* modulePtr, std::set const* commandIndicesPtr, typename std::set::const_iterator currentCommandIndexIt) : modulePtr(modulePtr), commandIndicesPtr(commandIndicesPtr), currentCommandIndexIt(currentCommandIndexIt) { + // Intentionally left empty + } + storm::prism::Module const* modulePtr; + std::set const* commandIndicesPtr; + typename std::set::const_iterator currentCommandIndexIt; + }; + template boost::optional>>> PrismNextStateGenerator::getActiveCommandsByActionIndex(uint_fast64_t const& actionIndex, CommandFilter const& commandFilter) { - boost::optional>>> result((std::vector>>())); - + + // First check whether there is at least one enabled command at each module + // This avoids evaluating unnecessarily many guards. + // If we find one module without an enabled command, we return boost::none. + // At the same time, we store pointers to the relevant modules, the relevant command sets and the first enabled command within each set. + // Iterate over all modules. + std::vector activeCommands; for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { storm::prism::Module const& module = program.getModule(i); @@ -446,14 +456,47 @@ namespace storm { // If the module contains the action, but there is no command in the module that is labeled with // this action, we don't have any feasible command combinations. if (commandIndices.empty()) { - return boost::optional>>>(); + return boost::none; } + // Look up commands by their indices and check if the guard evaluates to true in the given state. + bool hasOneEnabledCommand = false; + for (auto commandIndexIt = commandIndices.begin(), commandIndexIte = commandIndices.end(); commandIndexIt != commandIndexIte; ++commandIndexIt) { + storm::prism::Command const& command = module.getCommand(*commandIndexIt); + if (commandFilter != CommandFilter::All) { + STORM_LOG_ASSERT(commandFilter == CommandFilter::Markovian || commandFilter == CommandFilter::Probabilistic, "Unexpected command filter."); + if ((commandFilter == CommandFilter::Markovian) != command.isMarkovian()) { + continue; + } + } + if (this->evaluator->asBool(command.getGuardExpression())) { + // Found the first enabled command for this module. + hasOneEnabledCommand = true; + activeCommands.emplace_back(&module, &commandIndices, commandIndexIt); + break; + } + } + + if (!hasOneEnabledCommand) { + return boost::none; + } + } + + // If we reach this point, there has to be at least one active command for each relevant module. + std::vector>> result; + + // Iterate over all command sets. + for (auto const& activeCommand : activeCommands) { std::vector> commands; + auto commandIndexIt = activeCommand.currentCommandIndexIt; + // The command at the current position is already known to be enabled + commands.push_back(activeCommand.modulePtr->getCommand(*commandIndexIt)); + // Look up commands by their indices and add them if the guard evaluates to true in the given state. - for (uint_fast64_t commandIndex : commandIndices) { - storm::prism::Command const& command = module.getCommand(commandIndex); + auto commandIndexIte = activeCommand.commandIndicesPtr->end(); + for (++commandIndexIt; commandIndexIt != commandIndexIte; ++commandIndexIt) { + storm::prism::Command const& command = activeCommand.modulePtr->getCommand(*commandIndexIt); if (commandFilter != CommandFilter::All) { STORM_LOG_ASSERT(commandFilter == CommandFilter::Markovian || commandFilter == CommandFilter::Probabilistic, "Unexpected command filter."); if ((commandFilter == CommandFilter::Markovian) != command.isMarkovian()) { @@ -465,16 +508,10 @@ namespace storm { } } - // If there was no enabled command although the module has some command with the required action label, - // we must not return anything. - if (commands.size() == 0) { - return boost::none; - } - - result.get().push_back(std::move(commands)); + result.push_back(std::move(commands)); } - STORM_LOG_ASSERT(!result->empty(), "Expected non-empty list."); + STORM_LOG_ASSERT(!result.empty(), "Expected non-empty list."); return result; } @@ -576,8 +613,7 @@ namespace storm { } template - std::vector> PrismNextStateGenerator::getLabeledChoices(CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter) { - std::vector> result; + void PrismNextStateGenerator::addLabeledChoices(std::vector>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter) { for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) { boost::optional>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(actionIndex, commandFilter); @@ -604,10 +640,10 @@ namespace storm { // At this point, we applied all commands of the current command combination and newTargetStates // contains all target states and their respective probabilities. That means we are now ready to // add the choice to the list of transitions. - result.push_back(Choice(actionIndex)); + choices.push_back(Choice(actionIndex)); // Now create the actual distribution. - Choice& choice = result.back(); + Choice& choice = choices.back(); // Remember the choice label and origins only if we were asked to. if (this->options.isBuildChoiceLabelsSet()) { @@ -623,6 +659,7 @@ namespace storm { // Add the probabilities/rates to the newly created choice. ValueType probabilitySum = storm::utility::zero(); + choice.reserve(std::distance(distribution.begin(), distribution.end())); for (auto const& stateProbability : distribution) { choice.addProbability(stateProbability.getState(), stateProbability.getValue()); if (this->options.isExplorationChecksSet()) { @@ -664,8 +701,6 @@ namespace storm { } } } - - return result; } template diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h index b2eb1b555..e302fd02e 100644 --- a/src/storm/generator/PrismNextStateGenerator.h +++ b/src/storm/generator/PrismNextStateGenerator.h @@ -95,10 +95,11 @@ namespace storm { /*! * Retrieves all labeled choices possible from the given state. * + * @param choices The new choices are inserted in this vector * @param state The state for which to retrieve the unlabeled choices. * @return The labeled choices of the state. */ - std::vector> getLabeledChoices(CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter = CommandFilter::All); + void addLabeledChoices(std::vector>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter = CommandFilter::All); /*!