You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

123 lines
5.9 KiB

#include "KnownProbabilityTransformer.h"

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <utility>
#include <vector>
namespace storm {
namespace pomdp {
namespace transformer {

template<typename ValueType>
KnownProbabilityTransformer<ValueType>::KnownProbabilityTransformer() {
    // Intentionally left empty
}

/*!
 * Builds a reduced POMDP in which all states with known reachability probability are collapsed.
 *
 * Layout of the resulting POMDP:
 *  - State 0 is an absorbing state (self-loop with probability one) standing for every state in
 *    prob1States; it gets observation 0. It is always created, even if prob1States has no set bits.
 *  - State 1 is an absorbing state standing for every state in prob0States; it gets observation 1.
 *    It is only created if prob0States has at least one set bit.
 *  - All remaining states follow in increasing order of their original index. Their observations are
 *    renumbered consecutively, in order of first occurrence, directly after the sink observations.
 *
 * Every new state carries the union of the labels of the original states mapped onto it. Within each
 * choice, probabilities leading to original states that map to the same new state are summed. Only the
 * transition matrix, the state labeling and the observations are transferred to the result.
 *
 * @param pomdp The POMDP to transform.
 * @param prob0States States to be merged into the probability-0 sink (only read).
 * @param prob1States States to be merged into the probability-1 sink (only read).
 * @return The transformed POMDP. As a side effect, its model information is printed to std::cout.
 */
template<typename ValueType>
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>>
KnownProbabilityTransformer<ValueType>::transform(storm::models::sparse::Pomdp<ValueType> const &pomdp, storm::storage::BitVector &prob0States,
                                                  storm::storage::BitVector &prob1States) {
    // The probability-1 sink (state 0) always exists; the probability-0 sink (state 1) only if needed.
    bool const hasProb0Sink = !prob0States.empty();
    uint64_t const nrNewStates = hasProb0Sink ? 2 : 1;

    storm::storage::BitVector const unknownStates = ~(prob1States | prob0States);
    // One row group per sink plus one per unknown state. The labeling size and the column count of the
    // transition matrix must both agree with this number, so it is computed exactly once.
    uint64_t const newNrOfStates = nrNewStates + unknownStates.getNumberOfSetBits();

    // stateMap[s] is the index of original state s in the new model (dense, so a vector suffices).
    std::vector<uint64_t> stateMap(pomdp.getNumberOfStates(), 0);
    storm::models::sparse::StateLabeling newLabeling(newNrOfStates);

    // Adds every label of the original state to the given new state, creating labels on demand.
    auto copyLabels = [&pomdp, &newLabeling](uint64_t oldState, uint64_t newState) {
        for (auto const &label : pomdp.getStateLabeling().getLabelsOfState(oldState)) {
            if (!newLabeling.containsLabel(label)) {
                newLabeling.addLabel(label);
            }
            newLabeling.addLabelToState(label, newState);
        }
    };

    // New state 0 represents all states with probability 1.
    for (auto const state : prob1States) {
        stateMap[state] = 0;
        copyLabels(state, 0);
    }
    // New state 1 represents all states with probability 0.
    for (auto const state : prob0States) {
        stateMap[state] = 1;
        copyLabels(state, 1);
    }

    // Unknown states are numbered consecutively after the sinks; their observations are compressed likewise.
    std::map<uint32_t, uint32_t> observationMap;
    uint64_t nextStateId = nrNewStates;
    uint32_t nextObservation = static_cast<uint32_t>(nrNewStates);
    for (auto const state : unknownStates) {
        stateMap[state] = nextStateId;
        if (observationMap.emplace(pomdp.getObservation(state), nextObservation).second) {
            ++nextObservation;
        }
        copyLabels(state, nextStateId);
        ++nextStateId;
    }

    storm::storage::SparseMatrixBuilder<ValueType> smb(0, 0, 0, false, true);
    uint64_t currentRow = 0;
    uint64_t currentRowGroup = 0;
    std::vector<uint32_t> newObservations;
    newObservations.reserve(newNrOfStates);

    // State 0: absorbing probability-1 sink.
    smb.newRowGroup(currentRow);
    smb.addNextValue(currentRow, 0, storm::utility::one<ValueType>());
    newObservations.push_back(0);
    ++currentRowGroup;
    ++currentRow;

    // State 1: absorbing probability-0 sink (only if there are probability-0 states).
    if (hasProb0Sink) {
        smb.newRowGroup(currentRow);
        smb.addNextValue(currentRow, 1, storm::utility::one<ValueType>());
        newObservations.push_back(1);
        ++currentRowGroup;
        ++currentRow;
    }

    // Bind by reference to avoid copying the whole transition matrix.
    auto const &transitionMatrix = pomdp.getTransitionMatrix();
    for (auto const state : unknownStates) {
        smb.newRowGroup(currentRow);
        for (uint64_t offset = 0; offset < transitionMatrix.getRowGroupSize(state); ++offset) {
            // Several original targets may map to the same new state (the sinks), so sum their probabilities.
            std::map<uint64_t, ValueType> transitionsInAction;
            for (auto const &entry : transitionMatrix.getRow(state, offset)) {
                transitionsInAction[stateMap[entry.getColumn()]] += entry.getValue();
            }
            // std::map yields the targets in ascending order, so entries are added column by column.
            for (auto const &transition : transitionsInAction) {
                smb.addNextValue(currentRow, transition.first, transition.second);
            }
            ++currentRow;
        }
        ++currentRowGroup;
        newObservations.push_back(observationMap[pomdp.getObservation(state)]);
    }

    auto newTransitionMatrix = smb.build(currentRow, newNrOfStates, currentRowGroup);
    storm::storage::sparse::ModelComponents<ValueType> components(std::move(newTransitionMatrix), std::move(newLabeling));
    components.observabilityClasses = std::move(newObservations);
    auto newPomdp = std::make_shared<storm::models::sparse::Pomdp<ValueType>>(std::move(components));
    // NOTE(review): unconditional stdout output from a library routine; consider routing through logging.
    newPomdp->printModelInformationToStream(std::cout);
    return newPomdp;
}

template class KnownProbabilityTransformer<double>;

}  // namespace transformer
}  // namespace pomdp
}  // namespace storm