Browse Source

Added a traverser that finds out, whether a given reward model has state/action/transition rewards

tempestpy_adaptions
TimQu 6 years ago
parent
commit
793228c150
  1. 100
      src/storm/storage/jani/traverser/RewardModelInformation.cpp
  2. 59
      src/storm/storage/jani/traverser/RewardModelInformation.h

100
src/storm/storage/jani/traverser/RewardModelInformation.cpp

@ -0,0 +1,100 @@
#include "storm/storage/jani/traverser/RewardModelInformation.h"
#include "storm/storage/expressions/Expression.h"
#include "storm/storage/expressions/Variable.h"
#include "storm/storage/jani/Model.h"
namespace storm {
namespace jani {
RewardModelInformation::RewardModelInformation(bool hasStateRewards, bool hasActionRewards, bool hasTransitionRewards) : stateRewards(hasStateRewards), actionRewards(hasActionRewards), transitionRewards(hasTransitionRewards) {
// Intentionally left empty
}
RewardModelInformation::RewardModelInformation(Model const& model, std::string const& rewardModelNameIdentifier) : RewardModelInformation(model, model.getRewardModelExpression(rewardModelNameIdentifier)) {
// Intentionally left empty
}
RewardModelInformation::RewardModelInformation(Model const& model, storm::expressions::Expression const& rewardModelExpression) : stateRewards(false), actionRewards(false), transitionRewards(false), destinationDependendRewards(false) {
auto variablesInRewardExpression = rewardExpression.getVariables();
std::map<storm::expressions::Variable, storm::expressions::Expression> initialSubstitution;
for (auto const& v : variablesInRewardExpression) {
STORM_LOG_ASSERT(model.hasGlobalVariable(v.getName()), "Unable to find global variable " << v.getName() << " occurring in a reward expression.");
auto const& janiVar = model.getGlobalVariable(v.getName());
if (janiVar.hasInitExpression()) {
initialSubstitution.emplace(v, janiVar.getInitExpression());
}
}
auto initExpr = storm::jani::substituteJaniExpression(rewardExpression, initialSubstitution);
if (initExpr.containsVariables() || !storm::utility::isZero(initExpr.evaluateAsRational())) {
stateRewards = true;
actionRewards = true;
transitionRewards = true;
}
traverse(model, &variablesInRewardExpression);
}
void RewardModelInformation::traverse(Location const& location, boost::any const& data) {
auto const& vars = *boost::any_cast<std::set<storm::expressions::Variable>*>(data);
if (!hasStateRewards()) {
for (auto const& assignment : location.getAssignments().getTransientAssignments()) {
storm::jani::Variable const& assignedVariable = assignment.lValueIsArrayAccess() ? assignment.getLValue().getArray() : assignment.getVariable();
if (vars.count(assignedVariable.getExpressionVariable()) > 0) {
stateRewards = true;
break;
}
}
}
}
void RewardModelInformation::traverse(TemplateEdge const& templateEdge, boost::any const& data) {
auto const& vars = *boost::any_cast<std::set<storm::expressions::Variable>*>(data);
if (!hasActionRewards()) {
for (auto const& assignment : templateEdge.getAssignments().getTransientAssignments()) {
storm::jani::Variable const& assignedVariable = assignment.lValueIsArrayAccess() ? assignment.getLValue().getArray() : assignment.getVariable();
if (vars.count(assignedVariable.getExpressionVariable()) > 0) {
actionRewards = true;
break;
}
}
}
for (auto const& dest : templateEdge.getDestinations()) {
traverse(dest, data);
}
}
void RewardModelInformation::traverse(TemplateEdgeDestination const& templateEdgeDestination, boost::any const& data) {
auto const& vars = *boost::any_cast<std::set<storm::expressions::Variable>*>(data);
if (!hasTransitionRewards()) {
for (auto const& assignment : templateEdgeDestination.getOrderedAssignments().getTransientAssignments()) {
storm::jani::Variable const& assignedVariable = assignment.lValueIsArrayAccess() ? assignment.getLValue().getArray() : assignment.getVariable();
if (vars.count(assignedVariable.getExpressionVariable()) > 0) {
transitionRewards = true;
break;
}
}
}
}
RewardModelInformation RewardModelInformation::join(RewardModelInformation const& other) const {
return RewardModelInformation(this->hasStateRewards() || other.hasStateRewards(),
this->hasActionRewards() || other.hasActionRewards(),
this->hasTransitionRewards() || other.hasTransitionRewards());
}
bool RewardModelInformation::hasStateRewards () const {
return stateRewards;
}
bool RewardModelInformation::hasActionRewards () const {
return actionRewards;
}
bool RewardModelInformation::hasTransitionRewards () const {
return transitionReward;
}
}
}

59
src/storm/storage/jani/traverser/RewardModelInformation.h

@ -0,0 +1,59 @@
#pragma once
#include <boost/any.hpp>
#include "storm/storage/jani/traverser/JaniTraverser.h"
namespace storm {
namespace expressions {
class Variable;
class Expression;
}
namespace jani {
class Model;
class RewardModelInformation : public ConstJaniTraverser {
public:
RewardModelInformation(bool hasStateRewards, bool hasActionRewards, bool hasTransitionRewards);
RewardModelInformation(storm::jani::Model const& janiModel, std::string const& rewardModelNameIdentifier);
RewardModelInformation(storm::jani::Model const& janiModel, storm::expressions::Expression const& rewardModelExpression);
virtual ~RewardModelInformation() = default;
using ConstJaniTraverser::traverse;
virtual void traverse(Location const& location, boost::any const& data) override;
virtual void traverse(TemplateEdge const& templateEdge, boost::any const& data) override;
virtual void traverse(TemplateEdgeDestination const& TemplateEdgeDestination, boost::any const& data) override;
/*!
* Returns the resulting information when joining the two reward models
*/
RewardModelInformation join(RewardModelInformation const& other) const;
/*!
* Returns true iff the given reward model has state rewards
*/
bool hasStateRewards () const;
/*!
* Returns true iff the given reward model has action rewards
*/
bool hasActionRewards () const;
/*!
* Returns true iff the given reward model has transition rewards
*/
bool hasTransitionRewards () const;
bool stateRewards;
bool actionRewards;
bool transitionRewards;
};
}
}
Loading…
Cancel
Save