|
|
@ -155,24 +155,20 @@ namespace storm { |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
template<typename MatrixValueType> |
|
|
|
ValueType StandardRewardModel<ValueType>::getTotalStateActionReward(uint_fast64_t stateIndex, uint_fast64_t choiceIndex, storm::storage::SparseMatrix<MatrixValueType> const& transitionMatrix, MatrixValueType const& stateRewardWeight, MatrixValueType const& actionRewardWeight) const { |
|
|
|
ValueType result = this->hasStateRewards() ? (this->hasStateActionRewards() ? (ValueType) (this->getStateReward(stateIndex) * stateRewardWeight + this->getStateActionReward(choiceIndex) * actionRewardWeight) |
|
|
|
: (ValueType) (this->getStateReward(stateIndex) * stateRewardWeight)) |
|
|
|
: (this->hasStateActionRewards() ? (ValueType) (this->getStateActionReward(choiceIndex) * actionRewardWeight) |
|
|
|
: storm::utility::zero<ValueType>()); |
|
|
|
ValueType StandardRewardModel<ValueType>::getStateActionAndTransitionReward(uint_fast64_t choiceIndex, storm::storage::SparseMatrix<MatrixValueType> const& transitionMatrix) const { |
|
|
|
ValueType result = this->hasStateActionRewards() ? this->getStateActionReward(choiceIndex) : storm::utility::zero<ValueType>(); |
|
|
|
if (this->hasTransitionRewards()) { |
|
|
|
auto rewMatrixEntryIt = this->getTransitionRewardMatrix().begin(choiceIndex); |
|
|
|
for (auto const& transitionEntry : transitionMatrix.getRow(choiceIndex)) { |
|
|
|
assert(rewMatrixEntryIt != this->getTransitionRewardMatrix().end(choiceIndex)); |
|
|
|
if (transitionEntry.getColumn() < rewMatrixEntryIt->getColumn()) { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
// We assume that the transition reward matrix is a submatrix of the given transition matrix. Hence, the following must hold
|
|
|
|
assert(transitionEntry.getColumn() == rewMatrixEntryIt->getColumn()); |
|
|
|
result += actionRewardWeight * rewMatrixEntryIt->getValue() * storm::utility::convertNumber<ValueType>(transitionEntry.getValue()); |
|
|
|
++rewMatrixEntryIt; |
|
|
|
} |
|
|
|
} |
|
|
|
result += transitionMatrix.getPointwiseProductRowSum(getTransitionRewardMatrix(), choiceIndex); |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
template<typename MatrixValueType> |
|
|
|
ValueType StandardRewardModel<ValueType>::getTotalStateActionReward(uint_fast64_t stateIndex, uint_fast64_t choiceIndex, storm::storage::SparseMatrix<MatrixValueType> const& transitionMatrix, MatrixValueType const& stateRewardWeight, MatrixValueType const& actionRewardWeight) const { |
|
|
|
ValueType result = actionRewardWeight * getStateActionAndTransitionReward(choiceIndex, transitionMatrix); |
|
|
|
if (this->hasStateRewards()) { |
|
|
|
result += stateRewardWeight * this->getStateReward(stateIndex); |
|
|
|
} |
|
|
|
return result; |
|
|
|
} |
|
|
|