diff --git a/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.cpp b/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.cpp index 1c9cd87d2..e175139e6 100644 --- a/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.cpp +++ b/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.cpp @@ -1,5 +1,7 @@ #include "storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.h" +#include "storm/utility/vector.h" + namespace storm { namespace transformer { @@ -16,57 +18,89 @@ namespace storm { std::shared_ptr cache; }; - template - std::shared_ptr> ApplyFiniteSchedulerToPomdp::transform() const { - uint64_t nrStates = pomdp.getNumberOfStates(); - std::unordered_map> parameters; - bool nondeterminism = false; - storm::storage::SparseMatrixBuilder smb(nrStates, nrStates, 0, !nondeterminism, false, nrStates); + template + std::unordered_map> ApplyFiniteSchedulerToPomdp::getObservationChoiceWeights() const { + std::unordered_map> res; RationalFunctionConstructor ratFuncConstructor; - - for (uint64_t state = 0; state < nrStates; ++state) { - if (nondeterminism) { - smb.newRowGroup(state); - } + + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { auto observation = pomdp.getObservation(state); - auto it = parameters.find(observation); - std::vector localWeights; - if (it == parameters.end()) { - storm::RationalFunction lastWeight(1); + auto it = res.find(observation); + if (it == res.end()) { + std::vector weights; + storm::RationalFunction lastWeight = storm::utility::one(); for (uint64_t a = 0; a < pomdp.getNumberOfChoices(state) - 1; ++a) { std::string varName = "p" + std::to_string(observation) + "_" + std::to_string(a); - localWeights.push_back(ratFuncConstructor.translate(carl::freshRealVariable(varName))); - lastWeight -= localWeights.back(); + weights.push_back(ratFuncConstructor.translate(carl::freshRealVariable(varName))); + lastWeight -= weights.back(); } - localWeights.push_back(lastWeight); - parameters.emplace(observation, localWeights); - } else { - STORM_LOG_ASSERT(it->second.size() == pomdp.getNumberOfChoices(state), "Number of choices must be equal for every state with same number of actions"); - localWeights = it->second; + weights.push_back(lastWeight); + res.emplace(observation, weights); } + STORM_LOG_ASSERT(it == res.end() || it->second.size() == pomdp.getNumberOfChoices(state), "Number of choices must be equal for every state with same number of actions"); + } + return res; + } + + + + template + std::shared_ptr> ApplyFiniteSchedulerToPomdp::transform() const { + storm::storage::sparse::ModelComponents modelComponents; + + uint64_t nrStates = pomdp.getNumberOfStates(); + std::unordered_map> observationChoiceWeights = getObservationChoiceWeights(); + storm::storage::SparseMatrixBuilder smb(nrStates, nrStates, 0, true); + + for (uint64_t state = 0; state < nrStates; ++state) { + auto const& weights = observationChoiceWeights.at(pomdp.getObservation(state)); std::map weightedTransitions; for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { for (auto const& entry: pomdp.getTransitionMatrix().getRow(state, action)) { auto it = weightedTransitions.find(entry.getColumn()); if (it == weightedTransitions.end()) { - weightedTransitions[entry.getColumn()] = storm::utility::convertNumber(entry.getValue()) * localWeights[action]; //carl::rationalize(entry.getValue()) * localWeights[action]; + weightedTransitions[entry.getColumn()] = storm::utility::convertNumber(entry.getValue()) * weights[action]; } else { - it->second += storm::utility::convertNumber(entry.getValue()) * localWeights[action]; + it->second += storm::utility::convertNumber(entry.getValue()) * weights[action]; } } } - for (auto const& entry : weightedTransitions) { smb.addNextValue(state, entry.first, entry.second); } } - - // TODO rewards. - - storm::storage::sparse::ModelComponents modelComponents(smb.build(),pomdp.getStateLabeling()); + modelComponents.transitionMatrix = smb.build(); + + for (auto const& pomdpRewardModel : pomdp.getRewardModels()) { + std::vector stateRewards; + + if (pomdpRewardModel.second.hasStateRewards()) { + stateRewards = storm::utility::vector::convertNumericVector(pomdpRewardModel.second.getStateRewardVector()); + } else { + stateRewards.resize(nrStates, storm::utility::zero()); + } + if (pomdpRewardModel.second.hasStateActionRewards()) { + std::vector pomdpActionRewards = pomdpRewardModel.second.getStateActionRewardVector(); + for (uint64_t state = 0; state < nrStates; ++state) { + auto& stateReward = stateRewards[state]; + auto const& weights = observationChoiceWeights.at(pomdp.getObservation(state)); + uint64_t offset = pomdp.getTransitionMatrix().getRowGroupIndices()[state]; + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(state); ++action) { + if (!storm::utility::isZero(pomdpActionRewards[offset + action])) { + stateReward += storm::utility::convertNumber(pomdpActionRewards[offset + action]) * weights[action]; + } + } + } + } + storm::models::sparse::StandardRewardModel rewardModel(std::move(stateRewards)); + modelComponents.rewardModels.emplace(pomdpRewardModel.first, std::move(rewardModel)); + } + + modelComponents.stateLabeling = pomdp.getStateLabeling(); + return std::make_shared>(modelComponents); - + } template class ApplyFiniteSchedulerToPomdp; diff --git a/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.h b/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.h index 42fd45d94..d6b4967ce 100644 --- a/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.h +++ b/src/storm-pomdp/transformer/ApplyFiniteSchedulerToPomdp.h @@ -17,6 +17,12 @@ namespace storm { } std::shared_ptr> transform() const; + + private: + + + std::unordered_map> getObservationChoiceWeights() const; + storm::models::sparse::Pomdp const& pomdp; }; } diff --git a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp index bc1a47494..db550dd14 100644 --- a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp +++ b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp @@ -31,11 +31,11 @@ namespace storm { template storm::storage::SparseMatrix PomdpMemoryUnfolder::transformTransitions() const { storm::storage::SparseMatrix const& origTransitions = pomdp.getTransitionMatrix(); - storm::storage::SparseMatrixBuilder builder(pomdp.getNumberOfStates() * numMemoryStates * numMemoryStates, + storm::storage::SparseMatrixBuilder builder(pomdp.getNumberOfChoices() * numMemoryStates * numMemoryStates, pomdp.getNumberOfStates() * numMemoryStates, origTransitions.getEntryCount() * numMemoryStates * numMemoryStates, true, - false, + true, pomdp.getNumberOfStates() * numMemoryStates); uint64_t row = 0; @@ -96,7 +96,7 @@ namespace storm { } if (rewardModel.hasStateActionRewards()) { actionRewards = std::vector(); - stateRewards->reserve(pomdp.getNumberOfStates() * numMemoryStates * numMemoryStates); + actionRewards->reserve(pomdp.getNumberOfStates() * numMemoryStates * numMemoryStates); for (uint64_t modelState = 0; modelState < pomdp.getNumberOfStates(); ++modelState) { for (uint32_t memState = 0; memState < numMemoryStates; ++memState) { for (uint64_t origRow = pomdp.getTransitionMatrix().getRowGroupIndices()[modelState]; origRow < pomdp.getTransitionMatrix().getRowGroupIndices()[modelState + 1]; ++origRow) { @@ -108,7 +108,7 @@ namespace storm { } } } - STORM_LOG_THROW(rewardModel.hasTransitionRewards(), storm::exceptions::NotSupportedException, "Transition rewards are currently not supported."); + STORM_LOG_THROW(!rewardModel.hasTransitionRewards(), storm::exceptions::NotSupportedException, "Transition rewards are currently not supported."); return storm::models::sparse::StandardRewardModel(std::move(stateRewards), std::move(actionRewards)); }