From 98162d1d7eaeae32289ab532c95a10810aaf56ab Mon Sep 17 00:00:00 2001 From: sjunges Date: Mon, 28 Sep 2015 16:53:55 +0200 Subject: [PATCH] interface for rew. model extended for reinforcement learning Former-commit-id: b69474fc4fc324a5d54168926ae264823fad683f --- src/models/sparse/StandardRewardModel.cpp | 18 ++++++++++++++++++ src/models/sparse/StandardRewardModel.h | 9 ++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/models/sparse/StandardRewardModel.cpp b/src/models/sparse/StandardRewardModel.cpp index 5788dae10..befc06aa6 100644 --- a/src/models/sparse/StandardRewardModel.cpp +++ b/src/models/sparse/StandardRewardModel.cpp @@ -37,11 +37,13 @@ namespace storm { template std::vector const& StandardRewardModel::getStateRewardVector() const { + assert(this->hasStateRewards()); return this->optionalStateRewardVector.get(); } template std::vector& StandardRewardModel::getStateRewardVector() { + assert(this->hasStateRewards()); return this->optionalStateRewardVector.get(); } @@ -49,6 +51,13 @@ namespace storm { boost::optional> const& StandardRewardModel::getOptionalStateRewardVector() const { return this->optionalStateRewardVector; } + + template + ValueType const& StandardRewardModel::getStateReward(uint_fast64_t state) const { + assert(this->hasStateRewards()); + assert(state < this->optionalStateRewardVector.get().size()); + return this->optionalStateRewardVector.get()[state]; + } template bool StandardRewardModel::hasStateActionRewards() const { @@ -57,13 +66,22 @@ namespace storm { template std::vector const& StandardRewardModel::getStateActionRewardVector() const { + assert(this->hasStateActionRewards()); return this->optionalStateActionRewardVector.get(); } template std::vector& StandardRewardModel::getStateActionRewardVector() { + assert(this->hasStateActionRewards()); return this->optionalStateActionRewardVector.get(); } + + template + ValueType const& StandardRewardModel::getStateActionReward(uint_fast64_t choiceIndex) const { + assert(this->hasStateActionRewards()); + assert(choiceIndex < this->optionalStateActionRewardVector.get().size()); + return this->optionalStateActionRewardVector.get()[choiceIndex]; + } template boost::optional> const& StandardRewardModel::getOptionalStateActionRewardVector() const { diff --git a/src/models/sparse/StandardRewardModel.h b/src/models/sparse/StandardRewardModel.h index eb396d90a..144307e7c 100644 --- a/src/models/sparse/StandardRewardModel.h +++ b/src/models/sparse/StandardRewardModel.h @@ -74,6 +74,8 @@ namespace storm { * @return The state reward vector. */ std::vector& getStateRewardVector(); + + ValueType const& getStateReward(uint_fast64_t state) const; /*! * Retrieves an optional value that contains the state reward vector if there is one. @@ -104,7 +106,12 @@ namespace storm { * @return The state-action reward vector. */ std::vector& getStateActionRewardVector(); - + + /*! + * Retrieves the state-action reward for the given choice. + */ + ValueType const& getStateActionReward(uint_fast64_t choiceIndex) const; + /*! * Retrieves an optional value that contains the state-action reward vector if there is one. *