3 changed files with 148 additions and 3 deletions
			
			
		- 
					13src/storm-pomdp-cli/storm-pomdp.cpp
 - 
					121src/storm-pomdp/transformer/KnownProbabilityTransformer.cpp
 - 
					17src/storm-pomdp/transformer/KnownProbabilityTransformer.h
 
@ -0,0 +1,121 @@ | 
				
			|||
#include "KnownProbabilityTransformer.h"
 | 
				
			|||
 | 
				
			|||
namespace storm { | 
				
			|||
    namespace pomdp { | 
				
			|||
        namespace transformer { | 
				
			|||
 | 
				
			|||
            template<typename ValueType> | 
				
			|||
            KnownProbabilityTransformer<ValueType>::KnownProbabilityTransformer() { | 
				
			|||
                // Intentionally left empty
 | 
				
			|||
            } | 
				
			|||
 | 
				
			|||
            template<typename ValueType> | 
				
			|||
            std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> | 
				
			|||
            KnownProbabilityTransformer<ValueType>::transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States, | 
				
			|||
                                                              storm::storage::BitVector &prob1States) { | 
				
			|||
                std::map<uint64_t, uint64_t> stateMap; | 
				
			|||
                std::map<uint32_t, uint32_t> observationMap; | 
				
			|||
 | 
				
			|||
                storm::models::sparse::StateLabeling newLabeling(pomdp.getNumberOfStates() - prob0States.getNumberOfSetBits() - prob1States.getNumberOfSetBits() + 2); | 
				
			|||
 | 
				
			|||
                std::vector<uint32_t> newObservations; | 
				
			|||
 | 
				
			|||
                // New state 0 represents all states with probability 1
 | 
				
			|||
                for (auto const &iter : prob1States) { | 
				
			|||
                    stateMap[iter] = 0; | 
				
			|||
 | 
				
			|||
                    std::set<std::string> labelSet = pomdp.getStateLabeling().getLabelsOfState(iter); | 
				
			|||
                    for (auto const &label : labelSet) { | 
				
			|||
                        if (!newLabeling.containsLabel(label)) { | 
				
			|||
                            newLabeling.addLabel(label); | 
				
			|||
                        } | 
				
			|||
                        newLabeling.addLabelToState(label, 0); | 
				
			|||
                    } | 
				
			|||
                } | 
				
			|||
                // New state 1 represents all states with probability 0
 | 
				
			|||
                for (auto const &iter : prob0States) { | 
				
			|||
                    stateMap[iter] = 1; | 
				
			|||
                    for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(iter)) { | 
				
			|||
                        if (!newLabeling.containsLabel(label)) { | 
				
			|||
                            newLabeling.addLabel(label); | 
				
			|||
                        } | 
				
			|||
                        newLabeling.addLabelToState(label, 1); | 
				
			|||
                    } | 
				
			|||
                } | 
				
			|||
 | 
				
			|||
                storm::storage::BitVector unknownStates = ~(prob1States | prob0States); | 
				
			|||
                //If there are no states with probability 0 we set the next new state id to be 1, otherwise 2
 | 
				
			|||
                uint64_t newId = prob0States.empty() ? 1 : 2; | 
				
			|||
                uint64_t nextObservation = prob0States.empty() ? 1 : 2; | 
				
			|||
                for (auto const &iter : unknownStates) { | 
				
			|||
                    stateMap[iter] = newId; | 
				
			|||
                    if (observationMap.count(pomdp.getObservation(iter)) == 0) { | 
				
			|||
                        observationMap[pomdp.getObservation(iter)] = nextObservation; | 
				
			|||
                        ++nextObservation; | 
				
			|||
                    } | 
				
			|||
                    for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(iter)) { | 
				
			|||
                        if (!newLabeling.containsLabel(label)) { | 
				
			|||
                            newLabeling.addLabel(label); | 
				
			|||
                        } | 
				
			|||
                        newLabeling.addLabelToState(label, newId); | 
				
			|||
                    } | 
				
			|||
                    ++newId; | 
				
			|||
                } | 
				
			|||
 | 
				
			|||
                uint64_t newNrOfStates = pomdp.getNumberOfStates() - (prob1States.getNumberOfSetBits() + prob0States.getNumberOfSetBits()); | 
				
			|||
 | 
				
			|||
                uint64_t currentRow = 0; | 
				
			|||
                uint64_t currentRowGroup = 0; | 
				
			|||
                storm::storage::SparseMatrixBuilder<ValueType> smb(0, 0, 0, false, true); | 
				
			|||
                //new row for prob 1 state
 | 
				
			|||
                smb.newRowGroup(currentRow); | 
				
			|||
                smb.addNextValue(currentRow, 0, storm::utility::one<ValueType>()); | 
				
			|||
                newObservations.push_back(0); | 
				
			|||
                ++currentRowGroup; | 
				
			|||
                ++currentRow; | 
				
			|||
                if (!prob0States.empty()) { | 
				
			|||
                    smb.newRowGroup(currentRow); | 
				
			|||
                    smb.addNextValue(currentRow, 1, storm::utility::one<ValueType>()); | 
				
			|||
                    ++currentRowGroup; | 
				
			|||
                    ++currentRow; | 
				
			|||
                    newObservations.push_back(1); | 
				
			|||
                } | 
				
			|||
 | 
				
			|||
                auto transitionMatrix = pomdp.getTransitionMatrix(); | 
				
			|||
 | 
				
			|||
                for (auto const &iter : unknownStates) { | 
				
			|||
                    smb.newRowGroup(currentRow); | 
				
			|||
                    // First collect all transitions
 | 
				
			|||
                    //auto rowGroup = transitionMatrix.getRowGroup(iter);
 | 
				
			|||
                    for (uint64_t row = 0; row < transitionMatrix.getRowGroupSize(iter); ++row) { | 
				
			|||
                        std::map<uint64_t, ValueType> transitionsInAction; | 
				
			|||
                        for (auto const &entry : transitionMatrix.getRow(iter, row)) { | 
				
			|||
                            // here we use the state mapping to collect all probabilities to get to a state with prob 0/1
 | 
				
			|||
                            transitionsInAction[stateMap[entry.getColumn()]] += entry.getValue(); | 
				
			|||
                        } | 
				
			|||
                        for (auto const &transition : transitionsInAction) { | 
				
			|||
                            smb.addNextValue(currentRow, transition.first, transition.second); | 
				
			|||
                        } | 
				
			|||
                        ++currentRow; | 
				
			|||
                    } | 
				
			|||
                    ++currentRowGroup; | 
				
			|||
                    newObservations.push_back(observationMap[pomdp.getObservation(iter)]); | 
				
			|||
                } | 
				
			|||
 | 
				
			|||
                auto newTransitionMatrix = smb.build(currentRow, newNrOfStates, currentRowGroup); | 
				
			|||
                //STORM_PRINT(newTransitionMatrix)
 | 
				
			|||
                storm::storage::sparse::ModelComponents<ValueType> components(newTransitionMatrix, newLabeling); | 
				
			|||
                components.observabilityClasses = newObservations; | 
				
			|||
 | 
				
			|||
                auto newPomdp = storm::models::sparse::Pomdp<ValueType>(components); | 
				
			|||
 | 
				
			|||
                newPomdp.printModelInformationToStream(std::cout); | 
				
			|||
 | 
				
			|||
                return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(newPomdp); | 
				
			|||
            } | 
				
			|||
 | 
				
			|||
            template | 
				
			|||
            class KnownProbabilityTransformer<double>; | 
				
			|||
        } | 
				
			|||
    } | 
				
			|||
} | 
				
			|||
@ -0,0 +1,17 @@ | 
				
			|||
#include "storm/api/storm.h" | 
				
			|||
#include "storm/models/sparse/Pomdp.h" | 
				
			|||
 | 
				
			|||
namespace storm { | 
				
			|||
    namespace pomdp { | 
				
			|||
        namespace transformer { | 
				
			|||
            template<class ValueType> | 
				
			|||
            class KnownProbabilityTransformer { | 
				
			|||
            public: | 
				
			|||
                KnownProbabilityTransformer(); | 
				
			|||
 | 
				
			|||
                std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> | 
				
			|||
                transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States, storm::storage::BitVector &prob1States); | 
				
			|||
            }; | 
				
			|||
        } | 
				
			|||
    } | 
				
			|||
} | 
				
			|||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue