diff --git a/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.cpp b/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.cpp index 135d41071..405fe4fc6 100644 --- a/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.cpp +++ b/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.cpp @@ -31,13 +31,60 @@ namespace storm { subformula = std::make_shared(storm::logic::Formula::getTrueFormula(), subformula->asEventuallyFormula().getSubformula().asSharedPointer()); } - if (subformula->isUntilFormula()) { + if (formula.isProbabilityOperatorFormula() && subformula->isUntilFormula()) { if (!minimizes) { return transformMax(subformula->asUntilFormula()); } + } else if (formula.isRewardOperatorFormula() && subformula->isEventuallyFormula()) { + if (minimizes) { + return transformMinReward(subformula->asEventuallyFormula()); + } } STORM_LOG_THROW(false, storm::exceptions::InvalidPropertyException, "Mec elimination is not supported for the property " << formula); + return nullptr; + } + + template + std::shared_ptr> GlobalPomdpMecChoiceEliminator::transformMinReward(storm::logic::EventuallyFormula const& formula) const { + assert (formula.isRewardPathFormula()); + auto backwardTransitions = pomdp.getBackwardTransitions(); + storm::storage::BitVector allStates(pomdp.getNumberOfStates(), true); + auto prob1EStates = storm::utility::graph::performProb1E(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), backwardTransitions, allStates, checkPropositionalFormula(formula.getSubformula())); + STORM_LOG_THROW(prob1EStates.full(), storm::exceptions::InvalidPropertyException, "There are states from which the set of target states is not reachable. This is not supported."); + auto prob1AStates = storm::utility::graph::performProb1A(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), backwardTransitions, allStates, checkPropositionalFormula(formula.getSubformula())); + + auto mecs = decomposeEndComponents(~prob1AStates, ~allStates); + + // Get the 'out' state for every MEC with just a single out state + storm::storage::BitVector uniqueOutStates = getEndComponentsWithSingleOutStates(mecs); + + // For each observation of some 'out' state get the intersection of the choices that lead to the corresponding MEC + std::vector mecChoicesPerObservation = getEndComponentChoicesPerObservation(mecs, uniqueOutStates); + + // Filter the observations that have a state that is not an out state + storm::storage::BitVector stateFilter = ~uniqueOutStates; + for (auto const& state : stateFilter) { + mecChoicesPerObservation[pomdp.getObservation(state)].clear(); + } + + // It should not be possible to clear all choices for an observation since we only consider states that lead outside of its MEC. + for (auto& clearedChoices : mecChoicesPerObservation) { + STORM_LOG_ASSERT(clearedChoices.size() == 0 || !clearedChoices.full(), "Tried to clear all choices for an observation."); + } + + // transform the set of selected choices to global choice indices + storm::storage::BitVector choiceFilter(pomdp.getNumberOfChoices(), true); + stateFilter.complement(); + for (auto const& state : stateFilter) { + uint64_t offset = pomdp.getTransitionMatrix().getRowGroupIndices()[state]; + for (auto const& choice : mecChoicesPerObservation[pomdp.getObservation(state)]) { + choiceFilter.set(offset + choice, false); + } + } + + ChoiceSelector cs(pomdp); + return cs.transform(choiceFilter)->template as>(); } template @@ -46,45 +93,15 @@ namespace storm { auto prob01States = storm::utility::graph::performProb01Max(pomdp.getTransitionMatrix(), pomdp.getTransitionMatrix().getRowGroupIndices(), backwardTransitions, checkPropositionalFormula(formula.getLeftSubformula()), checkPropositionalFormula(formula.getRightSubformula())); auto mecs = decomposeEndComponents(~(prob01States.first | prob01States.second), prob01States.first); - std::vector mecChoicesPerObservation(pomdp.getNrObservations()); - storm::storage::BitVector uniqueOutStates(pomdp.getNumberOfStates(), false); - // Find the MECs that have only one 'out' state - for (auto const& mec : mecs) { - boost::optional uniqueOutState = boost::none; - for (auto const& stateActionsPair : mec) { - // Check whether this is an 'out' state, i.e., not all actions stay inside the MEC - if (stateActionsPair.second.size() != pomdp.getNumberOfChoices(stateActionsPair.first)) { - if (uniqueOutState) { - // we already found one out state, so this mec is invalid - uniqueOutState = boost::none; - break; - } else { - uniqueOutState = stateActionsPair.first; - } - } - } - if (uniqueOutState) { - uniqueOutStates.set(uniqueOutState.get(), true); - - storm::storage::BitVector localChoiceIndices(pomdp.getNumberOfChoices(uniqueOutState.get()), false); - uint64_t offset = pomdp.getTransitionMatrix().getRowGroupIndices()[uniqueOutState.get()]; - for (auto const& choice : mec.getChoicesForState(uniqueOutState.get())) { - assert(choice >= offset); - localChoiceIndices.set(choice - offset, true); - } - - auto& mecChoices = mecChoicesPerObservation[pomdp.getObservation(uniqueOutState.get())]; - if (mecChoices.size() == 0) { - mecChoices = localChoiceIndices; - } else { - STORM_LOG_ASSERT(mecChoices.size() == localChoiceIndices.size(), "Observation action count does not match for two states with the same observation"); - mecChoices &= localChoiceIndices; - } - } - } + // Get the 'out' state for every MEC with just a single out state + storm::storage::BitVector uniqueOutStates = getEndComponentsWithSingleOutStates(mecs); + + // For each observation of some 'out' state get the intersection of the choices that lead to the corresponding MEC + std::vector mecChoicesPerObservation = getEndComponentChoicesPerObservation(mecs, uniqueOutStates); + - // Filter the observations that have a state that is neither an out state, nor a prob0A or prob1A state - storm::storage::BitVector stateFilter = ~(uniqueOutStates | prob01States.first | prob01States.second); + // Filter the observations that have a state that is neither an out state, nor a prob0A state + storm::storage::BitVector stateFilter = ~(uniqueOutStates | prob01States.first); for (auto const& state : stateFilter) { mecChoicesPerObservation[pomdp.getObservation(state)].clear(); } @@ -108,6 +125,58 @@ namespace storm { return cs.transform(choiceFilter)->template as>(); } + template + storm::storage::BitVector GlobalPomdpMecChoiceEliminator::getEndComponentsWithSingleOutStates(storm::storage::MaximalEndComponentDecomposition const& mecs) const { + storm::storage::BitVector result(pomdp.getNumberOfStates(), false); + for (auto const& mec : mecs) { + boost::optional uniqueOutState = boost::none; + for (auto const& stateActionsPair : mec) { + // Check whether this is an 'out' state, i.e., not all actions stay inside the MEC + if (stateActionsPair.second.size() != pomdp.getNumberOfChoices(stateActionsPair.first)) { + if (uniqueOutState) { + // we already found one out state, so this mec is invalid + uniqueOutState = boost::none; + break; + } else { + uniqueOutState = stateActionsPair.first; + } + } + } + if (uniqueOutState) { + result.set(uniqueOutState.get(), true); + } + } + return result; + } + + template + std::vector GlobalPomdpMecChoiceEliminator::getEndComponentChoicesPerObservation(storm::storage::MaximalEndComponentDecomposition const& mecs, storm::storage::BitVector const& consideredStates) const { + + std::vector result(pomdp.getNrObservations()); + for (auto const& mec : mecs) { + for (auto const& stateActions : mec) { + if (consideredStates.get(stateActions.first)) { + storm::storage::BitVector localChoiceIndices(pomdp.getNumberOfChoices(stateActions.first), false); + uint64_t offset = pomdp.getTransitionMatrix().getRowGroupIndices()[stateActions.first]; + for (auto const& choice : stateActions.second) { + assert(choice >= offset); + localChoiceIndices.set(choice - offset, true); + } + + auto& mecChoices = result[pomdp.getObservation(stateActions.first)]; + if (mecChoices.size() == 0) { + mecChoices = localChoiceIndices; + } else { + STORM_LOG_ASSERT(mecChoices.size() == localChoiceIndices.size(), "Observation action count does not match for two states with the same observation"); + mecChoices &= localChoiceIndices; + } + } + } + } + return result; + } + + template storm::storage::MaximalEndComponentDecomposition GlobalPomdpMecChoiceEliminator::decomposeEndComponents(storm::storage::BitVector const& subsystem, storm::storage::BitVector const& redirectingStates) const { if (redirectingStates.empty()) { diff --git a/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.h b/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.h index 747b3fad0..668ee7cf0 100644 --- a/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.h +++ b/src/storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.h @@ -20,8 +20,11 @@ namespace storm { private: + std::shared_ptr> transformMinReward(storm::logic::EventuallyFormula const& formula) const; std::shared_ptr> transformMax(storm::logic::UntilFormula const& formula) const; storm::storage::MaximalEndComponentDecomposition decomposeEndComponents(storm::storage::BitVector const& subsystem, storm::storage::BitVector const& ignoredStates) const; + storm::storage::BitVector getEndComponentsWithSingleOutStates(storm::storage::MaximalEndComponentDecomposition const& mecs) const; + std::vector getEndComponentChoicesPerObservation(storm::storage::MaximalEndComponentDecomposition const& mecs, storm::storage::BitVector const& consideredStates) const; storm::storage::BitVector checkPropositionalFormula(storm::logic::Formula const& propositionalFormula) const;