diff --git a/src/storm/storage/jani/traverser/RewardModelInformation.cpp b/src/storm/storage/jani/traverser/RewardModelInformation.cpp new file mode 100644 index 000000000..a89139f4f --- /dev/null +++ b/src/storm/storage/jani/traverser/RewardModelInformation.cpp @@ -0,0 +1,100 @@ +#include "storm/storage/jani/traverser/RewardModelInformation.h" + +#include "storm/storage/expressions/Expression.h" +#include "storm/storage/expressions/Variable.h" +#include "storm/storage/jani/Model.h" + +namespace storm { + namespace jani { + + RewardModelInformation::RewardModelInformation(bool hasStateRewards, bool hasActionRewards, bool hasTransitionRewards) : stateRewards(hasStateRewards), actionRewards(hasActionRewards), transitionRewards(hasTransitionRewards) { + // Intentionally left empty + } + + RewardModelInformation::RewardModelInformation(Model const& model, std::string const& rewardModelNameIdentifier) : RewardModelInformation(model, model.getRewardModelExpression(rewardModelNameIdentifier)) { + // Intentionally left empty + } + + RewardModelInformation::RewardModelInformation(Model const& model, storm::expressions::Expression const& rewardModelExpression) : stateRewards(false), actionRewards(false), transitionRewards(false), destinationDependendRewards(false) { + auto variablesInRewardExpression = rewardExpression.getVariables(); + std::map initialSubstitution; + for (auto const& v : variablesInRewardExpression) { + STORM_LOG_ASSERT(model.hasGlobalVariable(v.getName()), "Unable to find global variable " << v.getName() << " occurring in a reward expression."); + auto const& janiVar = model.getGlobalVariable(v.getName()); + if (janiVar.hasInitExpression()) { + initialSubstitution.emplace(v, janiVar.getInitExpression()); + } + } + auto initExpr = storm::jani::substituteJaniExpression(rewardExpression, initialSubstitution); + if (initExpr.containsVariables() || !storm::utility::isZero(initExpr.evaluateAsRational())) { + stateRewards = true; + actionRewards = true; + transitionRewards = true; + } + traverse(model, &variablesInRewardExpression); + } + + void RewardModelInformation::traverse(Location const& location, boost::any const& data) { + auto const& vars = *boost::any_cast*>(data); + if (!hasStateRewards()) { + for (auto const& assignment : location.getAssignments().getTransientAssignments()) { + storm::jani::Variable const& assignedVariable = assignment.lValueIsArrayAccess() ? assignment.getLValue().getArray() : assignment.getVariable(); + if (vars.count(assignedVariable.getExpressionVariable()) > 0) { + stateRewards = true; + break; + } + } + } + } + + void RewardModelInformation::traverse(TemplateEdge const& templateEdge, boost::any const& data) { + auto const& vars = *boost::any_cast*>(data); + if (!hasActionRewards()) { + for (auto const& assignment : templateEdge.getAssignments().getTransientAssignments()) { + storm::jani::Variable const& assignedVariable = assignment.lValueIsArrayAccess() ? assignment.getLValue().getArray() : assignment.getVariable(); + if (vars.count(assignedVariable.getExpressionVariable()) > 0) { + actionRewards = true; + break; + } + } + } + for (auto const& dest : templateEdge.getDestinations()) { + traverse(dest, data); + } + } + + void RewardModelInformation::traverse(TemplateEdgeDestination const& templateEdgeDestination, boost::any const& data) { + auto const& vars = *boost::any_cast*>(data); + if (!hasTransitionRewards()) { + for (auto const& assignment : templateEdgeDestination.getOrderedAssignments().getTransientAssignments()) { + storm::jani::Variable const& assignedVariable = assignment.lValueIsArrayAccess() ? assignment.getLValue().getArray() : assignment.getVariable(); + if (vars.count(assignedVariable.getExpressionVariable()) > 0) { + transitionRewards = true; + break; + } + } + } + } + + RewardModelInformation RewardModelInformation::join(RewardModelInformation const& other) const { + return RewardModelInformation(this->hasStateRewards() || other.hasStateRewards(), + this->hasActionRewards() || other.hasActionRewards(), + this->hasTransitionRewards() || other.hasTransitionRewards()); + } + + bool RewardModelInformation::hasStateRewards () const { + return stateRewards; + } + + bool RewardModelInformation::hasActionRewards () const { + return actionRewards; + } + + bool RewardModelInformation::hasTransitionRewards () const { + return transitionReward; + } + + + } +} + diff --git a/src/storm/storage/jani/traverser/RewardModelInformation.h b/src/storm/storage/jani/traverser/RewardModelInformation.h new file mode 100644 index 000000000..c79f7a1dc --- /dev/null +++ b/src/storm/storage/jani/traverser/RewardModelInformation.h @@ -0,0 +1,59 @@ +#pragma once + + +#include + +#include "storm/storage/jani/traverser/JaniTraverser.h" + +namespace storm { + + namespace expressions { + class Variable; + class Expression; + } + + namespace jani { + + class Model; + + class RewardModelInformation : public ConstJaniTraverser { + public: + + RewardModelInformation(bool hasStateRewards, bool hasActionRewards, bool hasTransitionRewards); + RewardModelInformation(storm::jani::Model const& janiModel, std::string const& rewardModelNameIdentifier); + RewardModelInformation(storm::jani::Model const& janiModel, storm::expressions::Expression const& rewardModelExpression); + + virtual ~RewardModelInformation() = default; + using ConstJaniTraverser::traverse; + + virtual void traverse(Location const& location, boost::any const& data) override; + virtual void traverse(TemplateEdge const& templateEdge, boost::any const& data) override; + virtual void traverse(TemplateEdgeDestination const& TemplateEdgeDestination, boost::any const& data) override; + + /*! + * Returns the resulting information when joining the two reward models + */ + RewardModelInformation join(RewardModelInformation const& other) const; + + /*! + * Returns true iff the given reward model has state rewards + */ + bool hasStateRewards () const; + + /*! + * Returns true iff the given reward model has action rewards + */ + bool hasActionRewards () const; + + /*! + * Returns true iff the given reward model has transition rewards + */ + bool hasTransitionRewards () const; + + bool stateRewards; + bool actionRewards; + bool transitionRewards; + }; + } +} +