3 changed files with 148 additions and 3 deletions
			
			
		- 
					13src/storm-pomdp-cli/storm-pomdp.cpp
- 
					121src/storm-pomdp/transformer/KnownProbabilityTransformer.cpp
- 
					17src/storm-pomdp/transformer/KnownProbabilityTransformer.h
| @ -0,0 +1,121 @@ | |||||
|  | #include "KnownProbabilityTransformer.h"
 | ||||
|  | 
 | ||||
|  | namespace storm { | ||||
|  |     namespace pomdp { | ||||
|  |         namespace transformer { | ||||
|  | 
 | ||||
|  |             template<typename ValueType> | ||||
|  |             KnownProbabilityTransformer<ValueType>::KnownProbabilityTransformer() { | ||||
|  |                 // Intentionally left empty
 | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             template<typename ValueType> | ||||
|  |             std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> | ||||
|  |             KnownProbabilityTransformer<ValueType>::transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States, | ||||
|  |                                                               storm::storage::BitVector &prob1States) { | ||||
|  |                 std::map<uint64_t, uint64_t> stateMap; | ||||
|  |                 std::map<uint32_t, uint32_t> observationMap; | ||||
|  | 
 | ||||
|  |                 storm::models::sparse::StateLabeling newLabeling(pomdp.getNumberOfStates() - prob0States.getNumberOfSetBits() - prob1States.getNumberOfSetBits() + 2); | ||||
|  | 
 | ||||
|  |                 std::vector<uint32_t> newObservations; | ||||
|  | 
 | ||||
|  |                 // New state 0 represents all states with probability 1
 | ||||
|  |                 for (auto const &iter : prob1States) { | ||||
|  |                     stateMap[iter] = 0; | ||||
|  | 
 | ||||
|  |                     std::set<std::string> labelSet = pomdp.getStateLabeling().getLabelsOfState(iter); | ||||
|  |                     for (auto const &label : labelSet) { | ||||
|  |                         if (!newLabeling.containsLabel(label)) { | ||||
|  |                             newLabeling.addLabel(label); | ||||
|  |                         } | ||||
|  |                         newLabeling.addLabelToState(label, 0); | ||||
|  |                     } | ||||
|  |                 } | ||||
|  |                 // New state 1 represents all states with probability 0
 | ||||
|  |                 for (auto const &iter : prob0States) { | ||||
|  |                     stateMap[iter] = 1; | ||||
|  |                     for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(iter)) { | ||||
|  |                         if (!newLabeling.containsLabel(label)) { | ||||
|  |                             newLabeling.addLabel(label); | ||||
|  |                         } | ||||
|  |                         newLabeling.addLabelToState(label, 1); | ||||
|  |                     } | ||||
|  |                 } | ||||
|  | 
 | ||||
|  |                 storm::storage::BitVector unknownStates = ~(prob1States | prob0States); | ||||
|  |                 //If there are no states with probability 0 we set the next new state id to be 1, otherwise 2
 | ||||
|  |                 uint64_t newId = prob0States.empty() ? 1 : 2; | ||||
|  |                 uint64_t nextObservation = prob0States.empty() ? 1 : 2; | ||||
|  |                 for (auto const &iter : unknownStates) { | ||||
|  |                     stateMap[iter] = newId; | ||||
|  |                     if (observationMap.count(pomdp.getObservation(iter)) == 0) { | ||||
|  |                         observationMap[pomdp.getObservation(iter)] = nextObservation; | ||||
|  |                         ++nextObservation; | ||||
|  |                     } | ||||
|  |                     for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(iter)) { | ||||
|  |                         if (!newLabeling.containsLabel(label)) { | ||||
|  |                             newLabeling.addLabel(label); | ||||
|  |                         } | ||||
|  |                         newLabeling.addLabelToState(label, newId); | ||||
|  |                     } | ||||
|  |                     ++newId; | ||||
|  |                 } | ||||
|  | 
 | ||||
|  |                 uint64_t newNrOfStates = pomdp.getNumberOfStates() - (prob1States.getNumberOfSetBits() + prob0States.getNumberOfSetBits()); | ||||
|  | 
 | ||||
|  |                 uint64_t currentRow = 0; | ||||
|  |                 uint64_t currentRowGroup = 0; | ||||
|  |                 storm::storage::SparseMatrixBuilder<ValueType> smb(0, 0, 0, false, true); | ||||
|  |                 //new row for prob 1 state
 | ||||
|  |                 smb.newRowGroup(currentRow); | ||||
|  |                 smb.addNextValue(currentRow, 0, storm::utility::one<ValueType>()); | ||||
|  |                 newObservations.push_back(0); | ||||
|  |                 ++currentRowGroup; | ||||
|  |                 ++currentRow; | ||||
|  |                 if (!prob0States.empty()) { | ||||
|  |                     smb.newRowGroup(currentRow); | ||||
|  |                     smb.addNextValue(currentRow, 1, storm::utility::one<ValueType>()); | ||||
|  |                     ++currentRowGroup; | ||||
|  |                     ++currentRow; | ||||
|  |                     newObservations.push_back(1); | ||||
|  |                 } | ||||
|  | 
 | ||||
|  |                 auto transitionMatrix = pomdp.getTransitionMatrix(); | ||||
|  | 
 | ||||
|  |                 for (auto const &iter : unknownStates) { | ||||
|  |                     smb.newRowGroup(currentRow); | ||||
|  |                     // First collect all transitions
 | ||||
|  |                     //auto rowGroup = transitionMatrix.getRowGroup(iter);
 | ||||
|  |                     for (uint64_t row = 0; row < transitionMatrix.getRowGroupSize(iter); ++row) { | ||||
|  |                         std::map<uint64_t, ValueType> transitionsInAction; | ||||
|  |                         for (auto const &entry : transitionMatrix.getRow(iter, row)) { | ||||
|  |                             // here we use the state mapping to collect all probabilities to get to a state with prob 0/1
 | ||||
|  |                             transitionsInAction[stateMap[entry.getColumn()]] += entry.getValue(); | ||||
|  |                         } | ||||
|  |                         for (auto const &transition : transitionsInAction) { | ||||
|  |                             smb.addNextValue(currentRow, transition.first, transition.second); | ||||
|  |                         } | ||||
|  |                         ++currentRow; | ||||
|  |                     } | ||||
|  |                     ++currentRowGroup; | ||||
|  |                     newObservations.push_back(observationMap[pomdp.getObservation(iter)]); | ||||
|  |                 } | ||||
|  | 
 | ||||
|  |                 auto newTransitionMatrix = smb.build(currentRow, newNrOfStates, currentRowGroup); | ||||
|  |                 //STORM_PRINT(newTransitionMatrix)
 | ||||
|  |                 storm::storage::sparse::ModelComponents<ValueType> components(newTransitionMatrix, newLabeling); | ||||
|  |                 components.observabilityClasses = newObservations; | ||||
|  | 
 | ||||
|  |                 auto newPomdp = storm::models::sparse::Pomdp<ValueType>(components); | ||||
|  | 
 | ||||
|  |                 newPomdp.printModelInformationToStream(std::cout); | ||||
|  | 
 | ||||
|  |                 return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(newPomdp); | ||||
|  |             } | ||||
|  | 
 | ||||
|  |             template | ||||
|  |             class KnownProbabilityTransformer<double>; | ||||
|  |         } | ||||
|  |     } | ||||
|  | } | ||||
| @ -0,0 +1,17 @@ | |||||
|  | #include "storm/api/storm.h" | ||||
|  | #include "storm/models/sparse/Pomdp.h" | ||||
|  | 
 | ||||
|  | namespace storm { | ||||
|  |     namespace pomdp { | ||||
|  |         namespace transformer { | ||||
|  |             template<class ValueType> | ||||
|  |             class KnownProbabilityTransformer { | ||||
|  |             public: | ||||
|  |                 KnownProbabilityTransformer(); | ||||
|  | 
 | ||||
|  |                 std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> | ||||
|  |                 transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States, storm::storage::BitVector &prob1States); | ||||
|  |             }; | ||||
|  |         } | ||||
|  |     } | ||||
|  | } | ||||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue