Browse Source

Picking pivot state with quantitative information available now considers max diff

tempestpy_adaptions
dehnert 8 years ago
parent
commit
ef441f525a
  1. 88
      src/storm/abstraction/MenuGameRefiner.cpp

88
src/storm/abstraction/MenuGameRefiner.cpp

@ -21,8 +21,8 @@ namespace storm {
abstractor.get().refine(predicates); abstractor.get().refine(predicates);
} }
template<storm::dd::DdType Type>
storm::dd::Bdd<Type> pickPivotState(storm::dd::Bdd<Type> const& initialStates, storm::dd::Bdd<Type> const& transitions, std::set<storm::expressions::Variable> const& rowVariables, std::set<storm::expressions::Variable> const& columnVariables, storm::dd::Bdd<Type> const& pivotStates) {
template<storm::dd::DdType Type, typename ValueType>
storm::dd::Bdd<Type> pickPivotState(storm::dd::Bdd<Type> const& initialStates, storm::dd::Bdd<Type> const& transitions, std::set<storm::expressions::Variable> const& rowVariables, std::set<storm::expressions::Variable> const& columnVariables, storm::dd::Bdd<Type> const& pivotStates, boost::optional<QuantitativeResultMinMax<Type, ValueType>> const& quantitativeResult = boost::none) {
// Perform a BFS and pick the first pivot state we encounter. // Perform a BFS and pick the first pivot state we encounter.
storm::dd::Bdd<Type> pivotState; storm::dd::Bdd<Type> pivotState;
@ -40,9 +40,17 @@ namespace storm {
frontierPivotStates = frontier && pivotStates; frontierPivotStates = frontier && pivotStates;
if (!frontierPivotStates.isZero()) { if (!frontierPivotStates.isZero()) {
STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
pivotState = frontierPivotStates.existsAbstractRepresentative(rowVariables);
foundPivotState = true;
if (quantitativeResult) {
storm::dd::Add<Type, ValueType> frontierPivotStatesAdd = frontierPivotStates.template toAdd<ValueType>();
storm::dd::Add<Type, ValueType> diff = frontierPivotStatesAdd * quantitativeResult.get().max.values - frontierPivotStatesAdd * quantitativeResult.get().min.values;
pivotState = diff.maxAbstractRepresentative(rowVariables);
STORM_LOG_TRACE("Picked pivot state with difference " << diff.getMax() << " from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
foundPivotState = true;
} else {
pivotState = frontierPivotStates.existsAbstractRepresentative(rowVariables);
STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
foundPivotState = true;
}
} }
++level; ++level;
} }
@ -59,32 +67,6 @@ namespace storm {
storm::dd::Add<Type, ValueType> player1ChoiceAsAdd = player1Choice.template toAdd<ValueType>(); storm::dd::Add<Type, ValueType> player1ChoiceAsAdd = player1Choice.template toAdd<ValueType>();
auto pl1It = player1ChoiceAsAdd.begin(); auto pl1It = player1ChoiceAsAdd.begin();
uint_fast64_t player1Index = abstractionInformation.decodePlayer1Choice((*pl1It).first, abstractionInformation.getPlayer1VariableCount()); uint_fast64_t player1Index = abstractionInformation.decodePlayer1Choice((*pl1It).first, abstractionInformation.getPlayer1VariableCount());
#ifdef LOCAL_DEBUG
std::cout << "command index " << commandIndex << std::endl;
std::cout << program.get() << std::endl;
for (auto stateValue : pivotState.template toAdd<ValueType>()) {
std::stringstream stateName;
stateName << "\tpl1_";
for (auto const& var : currentGame->getRowVariables()) {
std::cout << "var " << var.getName() << std::endl;
if (stateValue.first.getBooleanValue(var)) {
stateName << "1";
} else {
stateName << "0";
}
}
std::cout << "pivot is " << stateName.str() << std::endl;
}
#endif
// storm::abstraction::prism::AbstractCommand<Type, ValueType>& abstractCommand = modules.front().getCommands()[commandIndex];
// storm::prism::Command const& concreteCommand = abstractCommand.getConcreteCommand();
#ifdef LOCAL_DEBUG
player1Choice.template toAdd<ValueType>().exportToDot("pl1choice_ref.dot");
std::cout << concreteCommand << std::endl;
(currentGame->getTransitionMatrix() * player1Choice.template toAdd<ValueType>()).exportToDot("cuttrans.dot");
#endif
// Check whether there are bottom states in the game and whether one of the choices actually picks the // Check whether there are bottom states in the game and whether one of the choices actually picks the
// bottom state as the successor. // bottom state as the successor.
@ -100,33 +82,11 @@ namespace storm {
} else { } else {
STORM_LOG_TRACE("No bottom state successor. Deriving a new predicate using weakest precondition."); STORM_LOG_TRACE("No bottom state successor. Deriving a new predicate using weakest precondition.");
#ifdef LOCAL_DEBUG
lowerChoice.template toAdd<ValueType>().exportToDot("lowerchoice_ref.dot");
upperChoice.template toAdd<ValueType>().exportToDot("upperchoice_ref.dot");
#endif
// Decode both choices to explicit mappings. // Decode both choices to explicit mappings.
#ifdef LOCAL_DEBUG
std::cout << "lower" << std::endl;
#endif
std::map<uint_fast64_t, storm::storage::BitVector> lowerChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(lowerChoice); std::map<uint_fast64_t, storm::storage::BitVector> lowerChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(lowerChoice);
#ifdef LOCAL_DEBUG
std::cout << "upper" << std::endl;
#endif
std::map<uint_fast64_t, storm::storage::BitVector> upperChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(upperChoice); std::map<uint_fast64_t, storm::storage::BitVector> upperChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(upperChoice);
STORM_LOG_ASSERT(lowerChoiceUpdateToSuccessorMapping.size() == upperChoiceUpdateToSuccessorMapping.size(), "Mismatching sizes after decode (" << lowerChoiceUpdateToSuccessorMapping.size() << " vs. " << upperChoiceUpdateToSuccessorMapping.size() << ")."); STORM_LOG_ASSERT(lowerChoiceUpdateToSuccessorMapping.size() == upperChoiceUpdateToSuccessorMapping.size(), "Mismatching sizes after decode (" << lowerChoiceUpdateToSuccessorMapping.size() << " vs. " << upperChoiceUpdateToSuccessorMapping.size() << ").");
#ifdef LOCAL_DEBUG
std::cout << "lower" << std::endl;
for (auto const& entry : lowerChoiceUpdateToSuccessorMapping) {
std::cout << entry.first << " -> " << entry.second << std::endl;
}
std::cout << "upper" << std::endl;
for (auto const& entry : upperChoiceUpdateToSuccessorMapping) {
std::cout << entry.first << " -> " << entry.second << std::endl;
}
#endif
// Now go through the mappings and find points of deviation. Currently, we take the first deviation. // Now go through the mappings and find points of deviation. Currently, we take the first deviation.
storm::expressions::Expression newPredicate; storm::expressions::Expression newPredicate;
auto lowerIt = lowerChoiceUpdateToSuccessorMapping.begin(); auto lowerIt = lowerChoiceUpdateToSuccessorMapping.begin();
@ -135,16 +95,11 @@ namespace storm {
for (; lowerIt != lowerIte; ++lowerIt, ++upperIt) { for (; lowerIt != lowerIte; ++lowerIt, ++upperIt) {
STORM_LOG_ASSERT(lowerIt->first == upperIt->first, "Update indices mismatch."); STORM_LOG_ASSERT(lowerIt->first == upperIt->first, "Update indices mismatch.");
uint_fast64_t updateIndex = lowerIt->first; uint_fast64_t updateIndex = lowerIt->first;
#ifdef LOCAL_DEBUG
std::cout << "update idx " << updateIndex << std::endl;
#endif
bool deviates = lowerIt->second != upperIt->second; bool deviates = lowerIt->second != upperIt->second;
if (deviates) { if (deviates) {
for (uint_fast64_t predicateIndex = 0; predicateIndex < lowerIt->second.size(); ++predicateIndex) { for (uint_fast64_t predicateIndex = 0; predicateIndex < lowerIt->second.size(); ++predicateIndex) {
if (lowerIt->second.get(predicateIndex) != upperIt->second.get(predicateIndex)) { if (lowerIt->second.get(predicateIndex) != upperIt->second.get(predicateIndex)) {
// Now we know the point of the deviation (command, update, predicate). // Now we know the point of the deviation (command, update, predicate).
std::cout << "ref" << std::endl;
std::cout << abstractionInformation.getPredicateByIndex(predicateIndex) << std::endl;
newPredicate = abstractionInformation.getPredicateByIndex(predicateIndex).substitute(abstractor.get().getVariableUpdates(player1Index, updateIndex)).simplify(); newPredicate = abstractionInformation.getPredicateByIndex(predicateIndex).substitute(abstractor.get().getVariableUpdates(player1Index, updateIndex)).simplify();
break; break;
} }
@ -208,7 +163,7 @@ namespace storm {
STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to proceed without pivot state candidates."); STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to proceed without pivot state candidates.");
// Now that we have the pivot state candidates, we need to pick one. // Now that we have the pivot state candidates, we need to pick one.
storm::dd::Bdd<Type> pivotState = pickPivotState<Type>(game.getInitialStates(), reachableTransitions, game.getRowVariables(), game.getColumnVariables(), pivotStates);
storm::dd::Bdd<Type> pivotState = pickPivotState<Type, ValueType>(game.getInitialStates(), reachableTransitions, game.getRowVariables(), game.getColumnVariables(), pivotStates);
// Compute the lower and the upper choice for the pivot state. // Compute the lower and the upper choice for the pivot state.
std::set<storm::expressions::Variable> variablesToAbstract = game.getNondeterminismVariables(); std::set<storm::expressions::Variable> variablesToAbstract = game.getNondeterminismVariables();
@ -283,7 +238,7 @@ namespace storm {
STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to refine without pivot state candidates."); STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to refine without pivot state candidates.");
// Now that we have the pivot state candidates, we need to pick one. // Now that we have the pivot state candidates, we need to pick one.
storm::dd::Bdd<Type> pivotState = pickPivotState<Type>(game.getInitialStates(), reachableTransitions, game.getRowVariables(), game.getColumnVariables(), pivotStates);
storm::dd::Bdd<Type> pivotState = pickPivotState<Type, ValueType>(game.getInitialStates(), reachableTransitions, game.getRowVariables(), game.getColumnVariables(), pivotStates, quantitativeResult);
// Compute the lower and the upper choice for the pivot state. // Compute the lower and the upper choice for the pivot state.
std::set<storm::expressions::Variable> variablesToAbstract = game.getNondeterminismVariables(); std::set<storm::expressions::Variable> variablesToAbstract = game.getNondeterminismVariables();
@ -315,8 +270,8 @@ namespace storm {
STORM_LOG_ASSERT(false, "Did not find choices from which to derive predicates."); STORM_LOG_ASSERT(false, "Did not find choices from which to derive predicates.");
} }
} }
return true;
} }
template<storm::dd::DdType Type, typename ValueType> template<storm::dd::DdType Type, typename ValueType>
bool MenuGameRefiner<Type, ValueType>::performRefinement(std::vector<storm::expressions::Expression> const& predicates) const { bool MenuGameRefiner<Type, ValueType>::performRefinement(std::vector<storm::expressions::Expression> const& predicates) const {
@ -331,6 +286,8 @@ namespace storm {
// Check which of the atoms are redundant in the sense that they are equivalent to a predicate we already have. // Check which of the atoms are redundant in the sense that they are equivalent to a predicate we already have.
for (auto const& atom : atoms) { for (auto const& atom : atoms) {
// Check whether the newly found atom is equivalent to an atom we already have in the predicate
// set or in the set that is to be added.
bool addAtom = true; bool addAtom = true;
for (auto const& oldPredicate : abstractionInformation.getPredicates()) { for (auto const& oldPredicate : abstractionInformation.getPredicates()) {
if (equivalenceChecker.areEquivalent(atom, oldPredicate)) { if (equivalenceChecker.areEquivalent(atom, oldPredicate)) {
@ -338,6 +295,12 @@ namespace storm {
break; break;
} }
} }
for (auto const& addedAtom : cleanedAtoms) {
if (equivalenceChecker.areEquivalent(addedAtom, atom)) {
addAtom = false;
break;
}
}
if (addAtom) { if (addAtom) {
cleanedAtoms.push_back(atom); cleanedAtoms.push_back(atom);
@ -347,8 +310,11 @@ namespace storm {
abstractor.get().refine(cleanedAtoms); abstractor.get().refine(cleanedAtoms);
} else { } else {
// If no splitting of the predicates is required, just forward the refinement request to the abstractor.
abstractor.get().refine(predicates); abstractor.get().refine(predicates);
} }
return true;
} }
template class MenuGameRefiner<storm::dd::DdType::CUDD, double>; template class MenuGameRefiner<storm::dd::DdType::CUDD, double>;

Loading…
Cancel
Save