diff --git a/src/storm-parsers/parser/JaniParser.cpp b/src/storm-parsers/parser/JaniParser.cpp index f3105d6ca..a5cc1c792 100644 --- a/src/storm-parsers/parser/JaniParser.cpp +++ b/src/storm-parsers/parser/JaniParser.cpp @@ -301,15 +301,11 @@ namespace storm { assert(bound == boost::none); STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "Forall and Exists are currently not supported"); } else if (opString == "Emin" || opString == "Emax") { - bool time=false; STORM_LOG_THROW(propertyStructure.count("exp") == 1, storm::exceptions::InvalidJaniException, "Expecting reward-expression for operator " << opString << " in " << scope.description); storm::expressions::Expression rewExpr = parseExpression(propertyStructure.at("exp"), scope.refine("Reward expression")); - if (rewExpr.isVariable()) { - time = false; - } else { - time = true; - } - + STORM_LOG_THROW(rewExpr.hasNumericalType(), storm::exceptions::InvalidJaniException, "Reward expression '" << rewExpr << "' does not have numerical type in " << scope.description); + std::string rewardName = rewExpr.toString(); + storm::logic::OperatorInformation opInfo; opInfo.optimalityType = opString == "Emin" ? storm::solver::OptimizationDirection::Minimize : storm::solver::OptimizationDirection::Maximize; opInfo.bound = bound; @@ -319,64 +315,43 @@ namespace storm { rewardAccumulation = parseRewardAccumulation(propertyStructure.at("accumulate"), scope.description); } + bool time = false; if (propertyStructure.count("step-instant") > 0) { STORM_LOG_THROW(propertyStructure.count("time-instant") == 0, storm::exceptions::NotSupportedException, "Storm does not support to have a step-instant and a time-instant in " + scope.description); STORM_LOG_THROW(propertyStructure.count("reward-instants") == 0, storm::exceptions::NotSupportedException, "Storm does not support to have a step-instant and a reward-instant in " + scope.description); storm::expressions::Expression stepInstantExpr = parseExpression(propertyStructure.at("step-instant"), scope.refine("Step instant")); if(rewardAccumulation.isEmpty()) { - if (rewExpr.isVariable()) { - std::string rewardName = rewExpr.getVariables().begin()->getName(); - return std::make_shared(std::make_shared(stepInstantExpr, storm::logic::TimeBoundType::Steps), rewardName, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Only simple reward expressions are currently supported"); - } + return std::make_shared(std::make_shared(stepInstantExpr, storm::logic::TimeBoundType::Steps), rewardName, opInfo); } else { - if (rewExpr.isVariable()) { - std::string rewardName = rewExpr.getVariables().begin()->getName(); - return std::make_shared(std::make_shared(storm::logic::TimeBound(false, stepInstantExpr), storm::logic::TimeBoundReference(storm::logic::TimeBoundType::Steps), rewardAccumulation), rewardName, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Only simple reward expressions are currently supported"); - } + return std::make_shared(std::make_shared(storm::logic::TimeBound(false, stepInstantExpr), storm::logic::TimeBoundReference(storm::logic::TimeBoundType::Steps), rewardAccumulation), rewardName, opInfo); } } else if (propertyStructure.count("time-instant") > 0) { STORM_LOG_THROW(propertyStructure.count("reward-instants") == 0, storm::exceptions::NotSupportedException, "Storm does not support to have a time-instant and a reward-instant in " + scope.description); - storm::expressions::Expression timeInstantExpr = parseExpression(propertyStructure.at("time-instant"), scope.refine("time instant")); - if(rewardAccumulation.isEmpty()) { - if (rewExpr.isVariable()) { - std::string rewardName = rewExpr.getVariables().begin()->getName(); - return std::make_shared(std::make_shared(timeInstantExpr, storm::logic::TimeBoundType::Time), rewardName, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Only simple reward expressions are currently supported"); - } + return std::make_shared(std::make_shared(timeInstantExpr, storm::logic::TimeBoundType::Time), rewardName, opInfo); } else { - if (rewExpr.isVariable()) { - std::string rewardName = rewExpr.getVariables().begin()->getName(); - return std::make_shared(std::make_shared(storm::logic::TimeBound(false, timeInstantExpr), storm::logic::TimeBoundReference(storm::logic::TimeBoundType::Time), rewardAccumulation), rewardName, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Only simple reward expressions are currently supported"); - } + return std::make_shared(std::make_shared(storm::logic::TimeBound(false, timeInstantExpr), storm::logic::TimeBoundReference(storm::logic::TimeBoundType::Time), rewardAccumulation), rewardName, opInfo); } } else if (propertyStructure.count("reward-instants") > 0) { std::vector bounds; std::vector boundReferences; for (auto const& rewInst : propertyStructure.at("reward-instants")) { - storm::expressions::Expression rewInstExpression = parseExpression(rewInst.at("exp"), scope.refine("Reward expression")); - STORM_LOG_THROW(!rewInstExpression.isVariable(), storm::exceptions::NotSupportedException, "Reward bounded cumulative reward formulas should only argue over reward expressions."); + storm::expressions::Expression rewInstRewardModelExpression = parseExpression(rewInst.at("exp"), scope.refine("Reward expression at reward instant")); + STORM_LOG_THROW(rewInstRewardModelExpression.hasNumericalType(), storm::exceptions::InvalidJaniException, "Reward expression '" << rewInstRewardModelExpression << "' does not have numerical type in " << scope.description); + std::string rewInstRewardModelName = rewInstRewardModelExpression.toString(); + if (!rewInstRewardModelExpression.isVariable()) { + nonTrivialRewardModelExpressions.emplace(rewInstRewardModelName, rewInstRewardModelExpression); + } storm::logic::RewardAccumulation boundRewardAccumulation = parseRewardAccumulation(rewInst.at("accumulate"), scope.description); - boundReferences.emplace_back(rewInstExpression.getVariables().begin()->getName(), boundRewardAccumulation); + boundReferences.emplace_back(rewInstRewardModelName, boundRewardAccumulation); storm::expressions::Expression rewInstantExpr = parseExpression(rewInst.at("instant"), scope.refine("reward instant")); bounds.emplace_back(false, rewInstantExpr); } - if (rewExpr.isVariable()) { - std::string rewardName = rewExpr.getVariables().begin()->getName(); - return std::make_shared(std::make_shared(bounds, boundReferences, rewardAccumulation), rewardName, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Only simple reward expressions are currently supported"); - } + return std::make_shared(std::make_shared(bounds, boundReferences, rewardAccumulation), rewardName, opInfo); } else { + time = !rewExpr.containsVariables() && storm::utility::isOne(rewExpr.evaluateAsRational()); std::shared_ptr subformula; if (propertyStructure.count("reach") > 0) { auto formulaContext = time ? storm::logic::FormulaContext::Time : storm::logic::FormulaContext::Reward; @@ -384,34 +359,16 @@ namespace storm { } else { subformula = std::make_shared(rewardAccumulation); } - if (rewExpr.isVariable()) { - assert(!time); - std::string rewardName = rewExpr.getVariables().begin()->getName(); - return std::make_shared(subformula, rewardName, opInfo); - } else if (!rewExpr.containsVariables()) { - assert(time); + if (time) { assert(subformula->isTotalRewardFormula() || subformula->isTimePathFormula()); - if(rewExpr.hasIntegerType()) { - if (rewExpr.evaluateAsInt() == 1) { - return std::make_shared(subformula, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Expected steps/time only works with constant one."); - } - } else if (rewExpr.hasRationalType()){ - if (rewExpr.evaluateAsDouble() == 1.0) { - - return std::make_shared(subformula, opInfo); - } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Expected steps/time only works with constant one."); - } - } else { - STORM_LOG_THROW(false, storm::exceptions::InvalidJaniException, "Only numerical reward expressions are allowed"); - } - + return std::make_shared(subformula, opInfo); } else { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "No complex reward expressions are supported at the moment"); + return std::make_shared(subformula, rewardName, opInfo); } } + if (!time && !rewExpr.isVariable()) { + nonTrivialRewardModelExpressions.emplace(rewardName, rewExpr); + } } else if (opString == "Smin" || opString == "Smax") { STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "Smin and Smax are currently not supported"); } else if (opString == "U" || opString == "F") { @@ -462,12 +419,14 @@ namespace storm { for (auto const& rbStructure : propertyStructure.at("reward-bounds")) { storm::jani::PropertyInterval pi = parsePropertyInterval(rbStructure.at("bounds"), scope.refine("reward-bounded until").clearVariables()); STORM_LOG_THROW(rbStructure.count("exp") == 1, storm::exceptions::InvalidJaniException, "Expecting reward-expression for operator " << opString << " in " << scope.description); - storm::expressions::Expression rewExpr = parseExpression(rbStructure.at("exp"), scope.refine("Reward expression")); - STORM_LOG_THROW(rewExpr.isVariable(), storm::exceptions::NotSupportedException, "Storm currently does not support complex reward expressions."); + storm::expressions::Expression rewInstRewardModelExpression = parseExpression(rbStructure.at("exp"), scope.refine("Reward expression at reward-bounds")); + STORM_LOG_THROW(rewInstRewardModelExpression.hasNumericalType(), storm::exceptions::InvalidJaniException, "Reward expression '" << rewInstRewardModelExpression << "' does not have numerical type in " << scope.description); + std::string rewInstRewardModelName = rewInstRewardModelExpression.toString(); + if (!rewInstRewardModelExpression.isVariable()) { + nonTrivialRewardModelExpressions.emplace(rewInstRewardModelName, rewInstRewardModelExpression); + } storm::logic::RewardAccumulation boundRewardAccumulation = parseRewardAccumulation(rbStructure.at("accumulate"), scope.description); - tbReferences.emplace_back(rewExpr.getVariables().begin()->getName(), boundRewardAccumulation); - std::string rewardName = rewExpr.getVariables().begin()->getName(); - STORM_LOG_WARN("Reward-type (steps, time) is deduced from model type."); + tbReferences.emplace_back(rewInstRewardModelName, boundRewardAccumulation); if (pi.hasLowerBound()) { lowerBounds.push_back(storm::logic::TimeBound(pi.lowerBoundStrict, pi.lowerBound)); } else { @@ -478,7 +437,6 @@ namespace storm { } else { upperBounds.push_back(boost::none); } - tbReferences.push_back(storm::logic::TimeBoundReference(rewardName)); } } if (!tbReferences.empty()) { diff --git a/src/storm-parsers/parser/JaniParser.h b/src/storm-parsers/parser/JaniParser.h index 304715801..4e7489b32 100644 --- a/src/storm-parsers/parser/JaniParser.h +++ b/src/storm-parsers/parser/JaniParser.h @@ -122,6 +122,7 @@ namespace storm { std::shared_ptr expressionManager; std::set labels = {}; + std::unordered_map nonTrivialRewardModelExpressions; bool allowRecursion = true; diff --git a/src/storm/storage/jani/ArrayEliminator.cpp b/src/storm/storage/jani/ArrayEliminator.cpp index e39a5d7e3..9bedb334a 100644 --- a/src/storm/storage/jani/ArrayEliminator.cpp +++ b/src/storm/storage/jani/ArrayEliminator.cpp @@ -445,6 +445,9 @@ namespace storm { if (model.hasInitialStatesRestriction()) { model.setInitialStatesRestriction(arrayExprEliminator->eliminate(model.getInitialStatesRestriction())); } + for (auto& nonTrivRew : model.getNonTrivialRewardExpressions()) { + nonTrivRew.second = arrayExprEliminator->eliminate(nonTrivRew.second); + } } virtual void traverse(Automaton& automaton, boost::any const& data) override { diff --git a/src/storm/storage/jani/FunctionEliminator.cpp b/src/storm/storage/jani/FunctionEliminator.cpp index 152dfc77f..fcdaa0b34 100644 --- a/src/storm/storage/jani/FunctionEliminator.cpp +++ b/src/storm/storage/jani/FunctionEliminator.cpp @@ -284,6 +284,9 @@ namespace storm { if (model.hasInitialStatesRestriction()) { model.setInitialStatesRestriction(globalFunctionEliminationVisitor.eliminate(model.getInitialStatesRestriction())); } + for (auto& nonTrivRew : model.getNonTrivialRewardExpressions()) { + nonTrivRew.second = globalFunctionEliminationVisitor.eliminate(nonTrivRew.second); + } } void traverse(Automaton& automaton, boost::any const& data) override { @@ -404,6 +407,11 @@ namespace storm { STORM_LOG_ASSERT(!containsFunctionCallExpression(model), "The model still seems to contain function calls."); } + storm::expressions::Expression eliminateFunctionCallsInExpression(storm::expressions::Expression const& expression, Model const& model) { + detail::FunctionEliminationExpressionVisitor visitor(&model.getGlobalFunctionDefinitions()); + return visitor.eliminate(expression); + } + } } diff --git a/src/storm/storage/jani/FunctionEliminator.h b/src/storm/storage/jani/FunctionEliminator.h index 791884aed..c7a3b74eb 100644 --- a/src/storm/storage/jani/FunctionEliminator.h +++ b/src/storm/storage/jani/FunctionEliminator.h @@ -5,6 +5,11 @@ namespace storm { + + namespace expressions { + class Expression; + } + namespace jani { class Model; class Property; @@ -13,6 +18,12 @@ namespace storm { * Eliminates all function references in the given model and the given properties by replacing them with their corresponding definitions. */ void eliminateFunctions(Model& model, std::vector& properties); + + /*! + * Eliminates all function calls in the given expression by replacing them with their corresponding definitions. + * Only global function definitions are considered. + */ + storm::expressions::Expression eliminateFunctionCallsInExpression(storm::expressions::Expression const& expression, Model const& model); } } diff --git a/src/storm/storage/jani/JSONExporter.cpp b/src/storm/storage/jani/JSONExporter.cpp index bf83d36f0..be2c74471 100644 --- a/src/storm/storage/jani/JSONExporter.cpp +++ b/src/storm/storage/jani/JSONExporter.cpp @@ -32,6 +32,8 @@ #include "storm/storage/jani/Property.h" #include "storm/storage/jani/traverser/AssignmentsFinder.h" #include "storm/storage/jani/expressions/JaniReduceNestingExpressionVisitor.h" +#include "storm/storage/jani/FunctionEliminator.h" +#include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h" namespace storm { namespace jani { @@ -162,20 +164,35 @@ namespace storm { } modernjson::json FormulaToJaniJson::constructRewardAccumulation(storm::logic::RewardAccumulation const& rewardAccumulation, std::string const& rewardModelName) const { + bool steps = false; + bool time = false; + bool exit = false; - storm::jani::Variable const& transientVar = model.getGlobalVariable(rewardModelName); - storm::jani::AssignmentsFinder::ResultType assignmentKinds; - STORM_LOG_THROW(model.hasGlobalVariable(rewardModelName), storm::exceptions::InvalidPropertyException, "Unable to find transient variable with name " << rewardModelName << "."); - if (transientVar.getInitExpression().containsVariables() || !storm::utility::isZero(transientVar.getInitExpression().evaluateAsRational())) { - assignmentKinds.hasLocationAssignment = true; - assignmentKinds.hasEdgeAssignment = true; - assignmentKinds.hasEdgeDestinationAssignment = true; + auto rewardExpression = storm::jani::eliminateFunctionCallsInExpression(model.getRewardModelExpression(rewardModelName), model); + + auto variablesInRewardExpression = rewardExpression.getVariables(); + std::map initialSubstitution; + for (auto const& v : variablesInRewardExpression) { + STORM_LOG_ASSERT(model.hasGlobalVariable(v.getName()), "Unable to find global variable " << v.getName() << " occurring in a reward expression."); + auto const& janiVar = model.getGlobalVariable(v.getName()); + if (janiVar.hasInitExpression()) { + initialSubstitution.emplace(v, janiVar.getInitExpression()); + } + auto assignmentKinds = storm::jani::AssignmentsFinder().find(model, v); + steps = steps || assignmentKinds.hasEdgeAssignment || assignmentKinds.hasEdgeDestinationAssignment; + time = time || (!model.isDeterministicModel() && assignmentKinds.hasLocationAssignment); + exit = exit || assignmentKinds.hasLocationAssignment; } - assignmentKinds = storm::jani::AssignmentsFinder().find(model, transientVar); - - bool steps = rewardAccumulation.isStepsSet() && (assignmentKinds.hasEdgeAssignment || assignmentKinds.hasEdgeDestinationAssignment); - bool time = rewardAccumulation.isTimeSet() && !model.isDiscreteTimeModel() && assignmentKinds.hasLocationAssignment; - bool exit = rewardAccumulation.isExitSet() && assignmentKinds.hasLocationAssignment; + storm::jani::substituteJaniExpression(rewardExpression, initialSubstitution); + if (rewardExpression.containsVariables() || !storm::utility::isZero(rewardExpression.evaluateAsRational())) { + steps = true; + time = true; + exit = true; + } + + steps = steps && rewardAccumulation.isStepsSet(); + time = time && rewardAccumulation.isTimeSet(); + exit = exit && rewardAccumulation.isExitSet(); return constructRewardAccumulation(storm::logic::RewardAccumulation(steps, time, exit)); } @@ -265,7 +282,7 @@ namespace storm { opDecl["step-bounds"] = propertyInterval; } else if(tbr.isRewardBound()) { modernjson::json rewbound; - rewbound["exp"] = tbr.getRewardName(); + rewbound["exp"] = buildExpression(model.getRewardModelExpression(tbr.getRewardName()), model.getConstants(), model.getGlobalVariables()); if (tbr.hasRewardAccumulation()) { rewbound["accumulate"] = constructRewardAccumulation(tbr.getRewardAccumulation(), tbr.getRewardName()); } else { @@ -512,7 +529,7 @@ namespace storm { opDecl["left"][instantName] = buildExpression(f.getSubformula().asInstantaneousRewardFormula().getBound(), model.getConstants(), model.getGlobalVariables()); } STORM_LOG_THROW(f.hasRewardModelName(), storm::exceptions::NotSupportedException, "Reward name has to be specified for Jani-conversion"); - opDecl["left"]["exp"] = rewardModelName; + opDecl["left"]["exp"] = buildExpression(model.getRewardModelExpression(rewardModelName), model.getConstants(), model.getGlobalVariables()); opDecl["right"] = buildExpression(bound.threshold, model.getConstants(), model.getGlobalVariables()); } else { if (f.hasOptimalityType()) { @@ -541,7 +558,7 @@ namespace storm { } else if (f.getSubformula().isInstantaneousRewardFormula()) { opDecl[instantName] = buildExpression(f.getSubformula().asInstantaneousRewardFormula().getBound(), model.getConstants(), model.getGlobalVariables()); } - opDecl["exp"] = rewardModelName; + opDecl["exp"] = buildExpression(model.getRewardModelExpression(rewardModelName), model.getConstants(), model.getGlobalVariables()); } return opDecl; } diff --git a/src/storm/storage/jani/Model.cpp b/src/storm/storage/jani/Model.cpp index 2e9a981d3..79a37d384 100644 --- a/src/storm/storage/jani/Model.cpp +++ b/src/storm/storage/jani/Model.cpp @@ -79,6 +79,7 @@ namespace storm { this->constants = other.constants; this->constantToIndex = other.constantToIndex; this->globalVariables = other.globalVariables; + this->nonTrivialRewardModels = other.nonTrivialRewardModels; this->automata = other.automata; this->automatonToIndex = other.automatonToIndex; this->composition = other.composition; @@ -435,6 +436,8 @@ namespace storm { // Otherwise, we need to actually flatten composition. Model flattenedModel(this->getName() + "_flattened", this->getModelType(), this->getJaniVersion(), this->getManager().shared_from_this()); + + flattenedModel.getModelFeatures() = getModelFeatures(); // Get an SMT solver for computing possible guard combinations. std::unique_ptr solver = smtSolverFactory->create(*expressionManager); @@ -463,6 +466,14 @@ namespace storm { flattenedModel.addConstant(constant); } + for (auto const& nonTrivRew : getNonTrivialRewardExpressions()) { + flattenedModel.addNonTrivialRewardExpression(nonTrivRew.first, nonTrivRew.second); + } + + for (auto const& funDef : getGlobalFunctionDefinitions()) { + flattenedModel.addFunctionDefinition(funDef.second); + } + std::vector> composedAutomata; for (auto const& element : parallelComposition.getSubcompositions()) { STORM_LOG_THROW(element->isAutomatonComposition(), storm::exceptions::WrongFormatException, "Cannot flatten recursive (not standard-compliant) composition."); @@ -558,6 +569,9 @@ namespace storm { if (automaton.get().hasInitialStatesRestriction()) { initialStatesRestriction = initialStatesRestriction && automaton.get().getInitialStatesRestriction(); } + for (auto const& funDef : automaton.get().getFunctionDefinitions()) { + newAutomaton.addFunctionDefinition(funDef.second); + } } newAutomaton.setInitialStatesRestriction(this->getInitialStatesExpression(composedAutomata)); @@ -748,7 +762,63 @@ namespace storm { storm::expressions::ExpressionManager& Model::getExpressionManager() const { return *expressionManager; } - + + bool Model::addNonTrivialRewardExpression(std::string const& identifier, storm::expressions::Expression const& rewardExpression) { + if (nonTrivialRewardModels.count(identifier) > 0) { + return false; + } else { + nonTrivialRewardModels.emplace(identifier, rewardExpression); + return true; + } + } + + storm::expressions::Expression Model::getRewardModelExpression(std::string const& identifier) const { + auto findRes = nonTrivialRewardModels.find(identifier); + if (findRes != nonTrivialRewardModels.end()) { + return findRes->second; + } else { + // Check whether the reward model refers to a global variable + if (globalVariables.hasVariable(identifier)) { + return globalVariables.getVariable(identifier).getExpressionVariable().getExpression(); + } else { + STORM_LOG_THROW(identifier.empty(), storm::exceptions::InvalidArgumentException, "Cannot find unknown reward model '" << identifier << "'."); + STORM_LOG_THROW(nonTrivialRewardModels.size() + globalVariables.getNumberOfNumericalTransientVariables() == 1, storm::exceptions::InvalidArgumentException, "Reference to standard reward model is ambiguous."); + if (nonTrivialRewardModels.size() == 1) { + return nonTrivialRewardModels.begin()->second; + } else { + for (auto const& variable : globalVariables.getTransientVariables()) { + if (variable.isRealVariable() || variable.isUnboundedIntegerVariable() || variable.isBoundedIntegerVariable()) { + return variable.getExpressionVariable().getExpression(); + } + } + } + } + } + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot find unknown reward model '" << identifier << "'."); + return storm::expressions::Expression(); + } + + std::vector> Model::getAllRewardModelExpressions() const { + std::vector> result; + for (auto const& nonTrivExpr : nonTrivialRewardModels) { + result.emplace_back(nonTrivExpr.first, nonTrivExpr.second); + } + for (auto const& variable : globalVariables.getTransientVariables()) { + if (variable.isRealVariable() || variable.isUnboundedIntegerVariable() || variable.isBoundedIntegerVariable()) { + result.emplace_back(variable.getName(), variable.getExpressionVariable().getExpression()); + } + } + return result; + } + + std::unordered_map const& Model::getNonTrivialRewardExpressions() const { + return nonTrivialRewardModels; + } + + std::unordered_map& Model::getNonTrivialRewardExpressions() { + return nonTrivialRewardModels; + } + uint64_t Model::addAutomaton(Automaton const& automaton) { auto it = automatonToIndex.find(automaton.getName()); STORM_LOG_THROW(it == automatonToIndex.end(), storm::exceptions::WrongFormatException, "Automaton with name '" << automaton.getName() << "' already exists."); @@ -947,6 +1017,10 @@ namespace storm { // Substitute constants in initial states expression. result.setInitialStatesRestriction(substituteJaniExpression(this->getInitialStatesRestriction(), constantSubstitution)); + for (auto& rewMod : result.getNonTrivialRewardExpressions()) { + rewMod.second = substituteJaniExpression(rewMod.second, constantSubstitution); + } + // Substitute constants in variables of automata and their edges. for (auto& automaton : result.getAutomata()) { automaton.substitute(constantSubstitution); @@ -996,6 +1070,10 @@ namespace storm { // Substitute in initial states expression. this->setInitialStatesRestriction(substituteJaniExpression(this->getInitialStatesRestriction(), substitution)); + for (auto& rewMod : getNonTrivialRewardExpressions()) { + rewMod.second = substituteJaniExpression(rewMod.second, substitution); + } + // Substitute in variables of automata and their edges. for (auto& automaton : this->getAutomata()) { automaton.substitute(substitution); diff --git a/src/storm/storage/jani/Model.h b/src/storm/storage/jani/Model.h index 0eff007e5..c0740a80c 100644 --- a/src/storm/storage/jani/Model.h +++ b/src/storm/storage/jani/Model.h @@ -270,6 +270,33 @@ namespace storm { */ storm::expressions::ExpressionManager& getExpressionManager() const; + /*! + * Adds a (non-trivial) reward model, i.e., a reward model that does not consist of a single, global, numerical variable. + * @return true if a new reward model was added and false if a reward model with this identifier is already present in the model (in which case no reward model is added) + */ + bool addNonTrivialRewardExpression(std::string const& identifier, storm::expressions::Expression const& rewardExpression); + + /*! + * Retrieves the defining reward expression of the reward model with the given identifier. + */ + storm::expressions::Expression getRewardModelExpression(std::string const& identifier) const; + + /*! + * Retrieves all available reward model names and expressions of the model. + * This includes defined non-trivial reward expressions as well as transient, global, numerical variables + */ + std::vector> getAllRewardModelExpressions() const; + + /*! + * Retrieves all available non-trivial reward model names and expressions of the model. + */ + std::unordered_map const& getNonTrivialRewardExpressions() const; + + /*! + * Retrieves all available non-trivial reward model names and expressions of the model. + */ + std::unordered_map& getNonTrivialRewardExpressions(); + /*! * Adds the given automaton to the automata of this model. */ @@ -614,6 +641,10 @@ namespace storm { /// A mapping from names to action indices. std::unordered_map actionToIndex; + /// A mapping from non-trivial reward model names to their defining expression. + /// (A reward model is trivial, if it is represented by a single, global, numeric variable) + std::unordered_map nonTrivialRewardModels; + /// The set of non-silent action indices. boost::container::flat_set nonsilentActionIndices; diff --git a/src/storm/storage/jani/VariableSet.cpp b/src/storm/storage/jani/VariableSet.cpp index f6a47e34c..a3af625d5 100644 --- a/src/storm/storage/jani/VariableSet.cpp +++ b/src/storm/storage/jani/VariableSet.cpp @@ -280,6 +280,16 @@ namespace storm { return result; } + uint_fast64_t VariableSet::getNumberOfNumericalTransientVariables() const { + uint_fast64_t result = 0; + for (auto const& variable : transientVariables) { + if (variable->isRealVariable() || variable->isUnboundedIntegerVariable() || variable->isBoundedIntegerVariable()) { + ++result; + } + } + return result; + } + typename detail::ConstVariables VariableSet::getTransientVariables() const { return detail::ConstVariables(transientVariables.begin(), transientVariables.end()); } diff --git a/src/storm/storage/jani/VariableSet.h b/src/storm/storage/jani/VariableSet.h index 227535e6f..4e4d2d511 100644 --- a/src/storm/storage/jani/VariableSet.h +++ b/src/storm/storage/jani/VariableSet.h @@ -218,6 +218,11 @@ namespace storm { */ uint_fast64_t getNumberOfUnboundedIntegerTransientVariables() const; + /*! + * Retrieves the number of numerical (i.e. real, or integer) transient variables in this variable set. + */ + uint_fast64_t getNumberOfNumericalTransientVariables() const; + /*! * Retrieves the transient variables in this variable set. */ diff --git a/src/storm/storage/jani/traverser/JaniTraverser.cpp b/src/storm/storage/jani/traverser/JaniTraverser.cpp index 72f979b2d..7682c9a88 100644 --- a/src/storm/storage/jani/traverser/JaniTraverser.cpp +++ b/src/storm/storage/jani/traverser/JaniTraverser.cpp @@ -21,6 +21,9 @@ namespace storm { if (model.hasInitialStatesRestriction()) { traverse(model.getInitialStatesRestriction(), data); } + for (auto& nonTrivRew : model.getNonTrivialRewardExpressions()) { + traverse(nonTrivRew.second, data); + } } void JaniTraverser::traverse(Action const& action, boost::any const& data) { @@ -184,6 +187,9 @@ namespace storm { if (model.hasInitialStatesRestriction()) { traverse(model.getInitialStatesRestriction(), data); } + for (auto const& nonTrivRew : model.getNonTrivialRewardExpressions()) { + traverse(nonTrivRew.second, data); + } } void ConstJaniTraverser::traverse(Action const& action, boost::any const& data) {