3 changed files with 148 additions and 3 deletions
			
			
		- 
					13src/storm-pomdp-cli/storm-pomdp.cpp
- 
					121src/storm-pomdp/transformer/KnownProbabilityTransformer.cpp
- 
					17src/storm-pomdp/transformer/KnownProbabilityTransformer.h
| @ -0,0 +1,121 @@ | |||
| #include "KnownProbabilityTransformer.h"
 | |||
| 
 | |||
| namespace storm { | |||
|     namespace pomdp { | |||
|         namespace transformer { | |||
| 
 | |||
|             template<typename ValueType> | |||
|             KnownProbabilityTransformer<ValueType>::KnownProbabilityTransformer() { | |||
|                 // Intentionally left empty
 | |||
|             } | |||
| 
 | |||
|             template<typename ValueType> | |||
|             std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> | |||
|             KnownProbabilityTransformer<ValueType>::transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States, | |||
|                                                               storm::storage::BitVector &prob1States) { | |||
|                 std::map<uint64_t, uint64_t> stateMap; | |||
|                 std::map<uint32_t, uint32_t> observationMap; | |||
| 
 | |||
|                 storm::models::sparse::StateLabeling newLabeling(pomdp.getNumberOfStates() - prob0States.getNumberOfSetBits() - prob1States.getNumberOfSetBits() + 2); | |||
| 
 | |||
|                 std::vector<uint32_t> newObservations; | |||
| 
 | |||
|                 // New state 0 represents all states with probability 1
 | |||
|                 for (auto const &iter : prob1States) { | |||
|                     stateMap[iter] = 0; | |||
| 
 | |||
|                     std::set<std::string> labelSet = pomdp.getStateLabeling().getLabelsOfState(iter); | |||
|                     for (auto const &label : labelSet) { | |||
|                         if (!newLabeling.containsLabel(label)) { | |||
|                             newLabeling.addLabel(label); | |||
|                         } | |||
|                         newLabeling.addLabelToState(label, 0); | |||
|                     } | |||
|                 } | |||
|                 // New state 1 represents all states with probability 0
 | |||
|                 for (auto const &iter : prob0States) { | |||
|                     stateMap[iter] = 1; | |||
|                     for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(iter)) { | |||
|                         if (!newLabeling.containsLabel(label)) { | |||
|                             newLabeling.addLabel(label); | |||
|                         } | |||
|                         newLabeling.addLabelToState(label, 1); | |||
|                     } | |||
|                 } | |||
| 
 | |||
|                 storm::storage::BitVector unknownStates = ~(prob1States | prob0States); | |||
|                 //If there are no states with probability 0 we set the next new state id to be 1, otherwise 2
 | |||
|                 uint64_t newId = prob0States.empty() ? 1 : 2; | |||
|                 uint64_t nextObservation = prob0States.empty() ? 1 : 2; | |||
|                 for (auto const &iter : unknownStates) { | |||
|                     stateMap[iter] = newId; | |||
|                     if (observationMap.count(pomdp.getObservation(iter)) == 0) { | |||
|                         observationMap[pomdp.getObservation(iter)] = nextObservation; | |||
|                         ++nextObservation; | |||
|                     } | |||
|                     for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(iter)) { | |||
|                         if (!newLabeling.containsLabel(label)) { | |||
|                             newLabeling.addLabel(label); | |||
|                         } | |||
|                         newLabeling.addLabelToState(label, newId); | |||
|                     } | |||
|                     ++newId; | |||
|                 } | |||
| 
 | |||
|                 uint64_t newNrOfStates = pomdp.getNumberOfStates() - (prob1States.getNumberOfSetBits() + prob0States.getNumberOfSetBits()); | |||
| 
 | |||
|                 uint64_t currentRow = 0; | |||
|                 uint64_t currentRowGroup = 0; | |||
|                 storm::storage::SparseMatrixBuilder<ValueType> smb(0, 0, 0, false, true); | |||
|                 //new row for prob 1 state
 | |||
|                 smb.newRowGroup(currentRow); | |||
|                 smb.addNextValue(currentRow, 0, storm::utility::one<ValueType>()); | |||
|                 newObservations.push_back(0); | |||
|                 ++currentRowGroup; | |||
|                 ++currentRow; | |||
|                 if (!prob0States.empty()) { | |||
|                     smb.newRowGroup(currentRow); | |||
|                     smb.addNextValue(currentRow, 1, storm::utility::one<ValueType>()); | |||
|                     ++currentRowGroup; | |||
|                     ++currentRow; | |||
|                     newObservations.push_back(1); | |||
|                 } | |||
| 
 | |||
|                 auto transitionMatrix = pomdp.getTransitionMatrix(); | |||
| 
 | |||
|                 for (auto const &iter : unknownStates) { | |||
|                     smb.newRowGroup(currentRow); | |||
|                     // First collect all transitions
 | |||
|                     //auto rowGroup = transitionMatrix.getRowGroup(iter);
 | |||
|                     for (uint64_t row = 0; row < transitionMatrix.getRowGroupSize(iter); ++row) { | |||
|                         std::map<uint64_t, ValueType> transitionsInAction; | |||
|                         for (auto const &entry : transitionMatrix.getRow(iter, row)) { | |||
|                             // here we use the state mapping to collect all probabilities to get to a state with prob 0/1
 | |||
|                             transitionsInAction[stateMap[entry.getColumn()]] += entry.getValue(); | |||
|                         } | |||
|                         for (auto const &transition : transitionsInAction) { | |||
|                             smb.addNextValue(currentRow, transition.first, transition.second); | |||
|                         } | |||
|                         ++currentRow; | |||
|                     } | |||
|                     ++currentRowGroup; | |||
|                     newObservations.push_back(observationMap[pomdp.getObservation(iter)]); | |||
|                 } | |||
| 
 | |||
|                 auto newTransitionMatrix = smb.build(currentRow, newNrOfStates, currentRowGroup); | |||
|                 //STORM_PRINT(newTransitionMatrix)
 | |||
|                 storm::storage::sparse::ModelComponents<ValueType> components(newTransitionMatrix, newLabeling); | |||
|                 components.observabilityClasses = newObservations; | |||
| 
 | |||
|                 auto newPomdp = storm::models::sparse::Pomdp<ValueType>(components); | |||
| 
 | |||
|                 newPomdp.printModelInformationToStream(std::cout); | |||
| 
 | |||
|                 return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(newPomdp); | |||
|             } | |||
| 
 | |||
|             template | |||
|             class KnownProbabilityTransformer<double>; | |||
|         } | |||
|     } | |||
| } | |||
| @ -0,0 +1,17 @@ | |||
| #include "storm/api/storm.h" | |||
| #include "storm/models/sparse/Pomdp.h" | |||
| 
 | |||
| namespace storm { | |||
|     namespace pomdp { | |||
|         namespace transformer { | |||
|             template<class ValueType> | |||
|             class KnownProbabilityTransformer { | |||
|             public: | |||
|                 KnownProbabilityTransformer(); | |||
| 
 | |||
|                 std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> | |||
|                 transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States, storm::storage::BitVector &prob1States); | |||
|             }; | |||
|         } | |||
|     } | |||
| } | |||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue