From c739f0befa6a2393f5a292f0f93ca3bb255d9cb5 Mon Sep 17 00:00:00 2001 From: TimQu Date: Fri, 14 Sep 2018 15:38:05 +0200 Subject: [PATCH] elimination of jani function --- src/storm-cli-utilities/model-handling.h | 20 +- src/storm-conv/api/storm-conv.cpp | 6 +- src/storm/storage/jani/Automaton.cpp | 5 +- src/storm/storage/jani/Automaton.h | 2 +- src/storm/storage/jani/FunctionDefinition.cpp | 19 + src/storm/storage/jani/FunctionDefinition.h | 8 + src/storm/storage/jani/FunctionEliminator.cpp | 403 ++++++++++++++++++ src/storm/storage/jani/FunctionEliminator.h | 19 + src/storm/storage/jani/LValue.cpp | 7 +- src/storm/storage/jani/LValue.h | 1 + src/storm/storage/jani/Model.cpp | 16 +- src/storm/storage/jani/Model.h | 11 +- .../expressions/FunctionCallExpression.cpp | 5 + .../jani/expressions/FunctionCallExpression.h | 1 + .../FunctionCallExpressionFinder.cpp | 124 ++++++ .../traverser/FunctionCallExpressionFinder.h | 20 + .../storage/jani/traverser/JaniTraverser.cpp | 20 + .../storage/jani/traverser/JaniTraverser.h | 2 + 18 files changed, 676 insertions(+), 13 deletions(-) create mode 100644 src/storm/storage/jani/FunctionEliminator.cpp create mode 100644 src/storm/storage/jani/FunctionEliminator.h create mode 100644 src/storm/storage/jani/traverser/FunctionCallExpressionFinder.cpp create mode 100644 src/storm/storage/jani/traverser/FunctionCallExpressionFinder.h diff --git a/src/storm-cli-utilities/model-handling.h b/src/storm-cli-utilities/model-handling.h index 057f3f436..f0dfc5e6e 100644 --- a/src/storm-cli-utilities/model-handling.h +++ b/src/storm-cli-utilities/model-handling.h @@ -154,11 +154,23 @@ namespace storm { } } + // Check whether transformations on the jani model are required if (output.model && output.model.get().isJaniModel()) { - // Check if arrays need to be eliminated - if (coreSettings.getEngine() != storm::settings::modules::CoreSettings::Engine::Sparse || buildSettings.isJitSet()) { - output.preprocessedProperties = output.properties; - output.model.get().asJaniModel().eliminateArrays(output.preprocessedProperties.get()); + auto& janiModel = output.model.get().asJaniModel(); + // Check if functions need to be eliminated + if (janiModel.getModelFeatures().hasFunctions()) { + if (!output.preprocessedProperties) { + output.preprocessedProperties = output.properties; + } + janiModel.substituteFunctions(output.preprocessedProperties.get()); + } + + // Check if arrays need to be eliminated. This should be done after! eliminating the functions + if (janiModel.getModelFeatures().hasArrays() && (coreSettings.getEngine() != storm::settings::modules::CoreSettings::Engine::Sparse || buildSettings.isJitSet())) { + if (!output.preprocessedProperties) { + output.preprocessedProperties = output.properties; + } + janiModel.eliminateArrays(output.preprocessedProperties.get()); } } return output; diff --git a/src/storm-conv/api/storm-conv.cpp b/src/storm-conv/api/storm-conv.cpp index 9a965fb65..c0e7252bd 100644 --- a/src/storm-conv/api/storm-conv.cpp +++ b/src/storm-conv/api/storm-conv.cpp @@ -39,9 +39,9 @@ namespace storm { janiModel.eliminateArrays(properties); } - //if (!options.allowFunctions && janiModel.getModelFeatures().hasFunctions()) { - //janiModel = janiModel.substituteFunctions(); - //} + if (!options.allowFunctions && janiModel.getModelFeatures().hasFunctions()) { + janiModel = janiModel.substituteFunctions(properties); + } if (options.modelName) { janiModel.setName(options.modelName.get()); diff --git a/src/storm/storage/jani/Automaton.cpp b/src/storm/storage/jani/Automaton.cpp index c1b375b2b..e8c3555a9 100644 --- a/src/storm/storage/jani/Automaton.cpp +++ b/src/storm/storage/jani/Automaton.cpp @@ -95,7 +95,7 @@ namespace storm { return functionDefinitions; } - std::unordered_map Automaton::getFunctionDefinitions() { + std::unordered_map& Automaton::getFunctionDefinitions() { return functionDefinitions; } @@ -396,6 +396,9 @@ namespace storm { } void Automaton::substitute(std::map const& substitution) { + for (auto& functionDefinition : this->getFunctionDefinitions()) { + functionDefinition.second.substitute(substitution); + } for (auto& variable : this->getVariables().getBoundedIntegerVariables()) { variable.substitute(substitution); } diff --git a/src/storm/storage/jani/Automaton.h b/src/storm/storage/jani/Automaton.h index ed67d2efc..980fbaca7 100644 --- a/src/storm/storage/jani/Automaton.h +++ b/src/storm/storage/jani/Automaton.h @@ -112,7 +112,7 @@ namespace storm { /*! * Retrieves all function definitions of this automaton */ - std::unordered_map getFunctionDefinitions(); + std::unordered_map& getFunctionDefinitions(); /*! * Retrieves whether the automaton has a location with the given name. diff --git a/src/storm/storage/jani/FunctionDefinition.cpp b/src/storm/storage/jani/FunctionDefinition.cpp index 2bac96f91..8be6beac6 100644 --- a/src/storm/storage/jani/FunctionDefinition.cpp +++ b/src/storm/storage/jani/FunctionDefinition.cpp @@ -1,4 +1,9 @@ #include "storm/storage/jani/FunctionDefinition.h" +#include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h" + + +#include "storm/utility/macros.h" +#include "storm/exceptions/InvalidArgumentException.h" namespace storm { namespace jani { @@ -23,6 +28,20 @@ namespace storm { return functionBody; } + storm::expressions::Expression FunctionDefinition::call(std::vector> const& arguments ) const { + // substitute the parameters in the function body + STORM_LOG_THROW(arguments.size() == parameters.size(), storm::exceptions::InvalidArgumentException, "The number of arguments does not match the number of parameters."); + std::unordered_map parameterSubstitution; + for (uint64_t i = 0; i < arguments.size(); ++i) { + parameterSubstitution.emplace(parameters[i], arguments[i]); + } + return substituteJaniExpression(functionBody, parameterSubstitution); + } + + void FunctionDefinition::substitute(std::map const& substitution) { + this->setFunctionBody(substituteJaniExpression(this->getFunctionBody(), substitution)); + } + void FunctionDefinition::setFunctionBody(storm::expressions::Expression const& body) { functionBody = body; } diff --git a/src/storm/storage/jani/FunctionDefinition.h b/src/storm/storage/jani/FunctionDefinition.h index 51eabf937..4a46bb663 100644 --- a/src/storm/storage/jani/FunctionDefinition.h +++ b/src/storm/storage/jani/FunctionDefinition.h @@ -43,6 +43,14 @@ namespace storm { */ void setFunctionBody(storm::expressions::Expression const& body); + /*! + * Calls the function with the given arguments + */ + storm::expressions::Expression call(std::vector> const& arguments ) const; + + void substitute(std::map const& substitution); + + private: // The name of the function. std::string name; diff --git a/src/storm/storage/jani/FunctionEliminator.cpp b/src/storm/storage/jani/FunctionEliminator.cpp new file mode 100644 index 000000000..50d5deb95 --- /dev/null +++ b/src/storm/storage/jani/FunctionEliminator.cpp @@ -0,0 +1,403 @@ +#include "storm/storage/jani/FunctionEliminator.h" + +#include + +#include "storm/storage/expressions/ExpressionVisitor.h" +#include "storm/storage/jani/expressions/JaniExpressionVisitor.h" +#include "storm/storage/jani/expressions/JaniExpressions.h" +#include "storm/storage/jani/Variable.h" +#include "storm/storage/jani/traverser/JaniTraverser.h" +#include "storm/storage/jani/traverser/FunctionCallExpressionFinder.h" +#include "storm/storage/jani/Model.h" +#include "storm/storage/jani/Property.h" + +#include "storm/storage/expressions/Expressions.h" +#include "storm/storage/expressions/ExpressionManager.h" + +#include "storm/exceptions/UnexpectedException.h" +#include "storm/exceptions/NotSupportedException.h" + +namespace storm { + + + + namespace jani { + namespace detail { + + class FunctionEliminationExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor { + public: + + typedef std::shared_ptr BaseExprPtr; + + FunctionEliminationExpressionVisitor(std::unordered_map const* globalFunctions, std::unordered_map const* localFunctions = nullptr) : globalFunctions(globalFunctions), localFunctions(localFunctions) {} + + virtual ~FunctionEliminationExpressionVisitor() = default; + + FunctionEliminationExpressionVisitor setLocalFunctions(std::unordered_map const* localFunctions) { + return FunctionEliminationExpressionVisitor(this->globalFunctions, localFunctions); + } + + storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) { + auto res = storm::expressions::Expression(boost::any_cast(expression.accept(*this, boost::any()))); + return res.simplify(); + } + + virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override { + BaseExprPtr conditionExpression = boost::any_cast(expression.getCondition()->accept(*this, data)); + BaseExprPtr thenExpression = boost::any_cast(expression.getThenExpression()->accept(*this, data)); + BaseExprPtr elseExpression = boost::any_cast(expression.getElseExpression()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::IfThenElseExpression(expression.getManager(), thenExpression->getType(), conditionExpression, thenExpression, elseExpression))); + } + } + + virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override { + BaseExprPtr firstExpression = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); + BaseExprPtr secondExpression = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); + } + } + + virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override { + BaseExprPtr firstExpression = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); + BaseExprPtr secondExpression = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); + } + } + + virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override { + BaseExprPtr firstExpression = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); + BaseExprPtr secondExpression = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getRelationType()))); + } + } + + virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const& data) override { + return expression.getSharedPointer(); + } + + virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override { + BaseExprPtr operandExpression = boost::any_cast(expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); + } + } + + virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override { + BaseExprPtr operandExpression = boost::any_cast(expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); + } + } + + virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override { + return expression.getSharedPointer(); + } + + virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override { + return expression.getSharedPointer(); + } + + virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override { + return expression.getSharedPointer(); + } + + virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override { + STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ")."); + uint64_t size = expression.size()->evaluateAsInt(); + std::vector elements; + bool changed = false; + for (uint64_t i = 0; i(expression.at(i)->accept(*this, data)); + if (element.get() != expression.at(i).get()) { + changed = true; + } + elements.push_back(element); + } + if (changed) { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::ValueArrayExpression(expression.getManager(), expression.getType(), elements))); + } else { + return expression.getSharedPointer(); + } + } + + virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override { + BaseExprPtr sizeExpression = boost::any_cast(expression.size()->accept(*this, data)); + BaseExprPtr elementExpression = boost::any_cast(expression.getElementExpression()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (sizeExpression.get() == expression.size().get() && elementExpression.get() == expression.getElementExpression().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::ConstructorArrayExpression(expression.getManager(), expression.getType(), sizeExpression, expression.getIndexVar(), elementExpression))); + } + } + + virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override { + BaseExprPtr firstExpression = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); + BaseExprPtr secondExpression = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new storm::expressions::ArrayAccessExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression))); + } + } + + virtual boost::any visit(storm::expressions::FunctionCallExpression const& expression, boost::any const& data) override { + // Find the associated function definition + FunctionDefinition const* funDef = nullptr; + if (localFunctions != nullptr) { + auto funDefIt = localFunctions->find(expression.getIdentifier()); + if (funDefIt != localFunctions->end()) { + funDef = &(funDefIt->second); + } + } + if (globalFunctions != nullptr) { + auto funDefIt = globalFunctions->find(expression.getIdentifier()); + if (funDefIt != globalFunctions->end()) { + funDef = &(funDefIt->second); + } + } + + STORM_LOG_THROW(funDef != nullptr, storm::exceptions::UnexpectedException, "Unable to find function definition for function call " << expression << "."); + + return boost::any_cast(funDef->call(expression.getArguments()).getBaseExpression().accept(*this, data)); + } + + private: + std::unordered_map const* globalFunctions; + std::unordered_map const* localFunctions; + }; + + class FunctionEliminatorTraverser : public JaniTraverser { + public: + + FunctionEliminatorTraverser() = default; + + virtual ~FunctionEliminatorTraverser() = default; + + void eliminate(Model& model, std::vector& properties) { + + // Replace all function calls by the function definition + traverse(model, boost::any()); + + // Replace function definitions in properties + if (!model.getGlobalFunctionDefinitions().empty()) { + FunctionEliminationExpressionVisitor v(&model.getGlobalFunctionDefinitions()); + for (auto& property : properties) { + property = property.substitute([&v](storm::expressions::Expression const& exp) {return v.eliminate(exp);}); + } + } + + // Erase function definitions in model and automata + model.getGlobalFunctionDefinitions().clear(); + for (auto& automaton : model.getAutomata()) { + automaton.getFunctionDefinitions().clear(); + } + + // Clear the model feature 'functions' + model.getModelFeatures().remove(ModelFeature::Functions); + + // finalize the model + model.finalize(); + } + + + + // To detect cyclic dependencies between function bodies, we need to eliminate the functions in a topological order + enum class FunctionDefinitionStatus {Unprocessed, Current, Processed}; + void eliminateFunctionsInFunctionBodies(FunctionEliminationExpressionVisitor& eliminationVisitor, std::unordered_map& functions, std::unordered_map& status, std::string const& current) { + status[current] = FunctionDefinitionStatus::Current; + FunctionDefinition& funDef = functions.find(current)->second; + auto calledFunctions = getOccurringFunctionCalls(funDef.getFunctionBody()); + for (auto const& calledFunction : calledFunctions) { + STORM_LOG_THROW(calledFunction != current, storm::exceptions::NotSupportedException, "Function '" << calledFunction << "' calls itself. This is not supported."); + auto calledStatus = status.find(calledFunction); + // Check whether the called function belongs to the ones that actually needed processing + if (calledStatus != status.end()) { + STORM_LOG_THROW(calledStatus->second != FunctionDefinitionStatus::Current, storm::exceptions::NotSupportedException, "Found cyclic dependencies between functions '" << calledFunction << "' and '" << current << "'. This is not supported."); + if (calledStatus->second == FunctionDefinitionStatus::Unprocessed) { + eliminateFunctionsInFunctionBodies(eliminationVisitor, functions, status, calledFunction); + } + } + } + // At this point, all called functions are processed already. So we can finally process this one. + funDef.setFunctionBody(eliminationVisitor.eliminate(funDef.getFunctionBody())); + status[current] = FunctionDefinitionStatus::Processed; + } + void eliminateFunctionsInFunctionBodies(FunctionEliminationExpressionVisitor& eliminationVisitor, std::unordered_map& functions) { + + std::unordered_map status; + for (auto const& f : functions) { + status.emplace(f.first, FunctionDefinitionStatus::Unprocessed); + } + for (auto const& f : functions) { + if (status[f.first] == FunctionDefinitionStatus::Unprocessed) { + eliminateFunctionsInFunctionBodies(eliminationVisitor, functions, status, f.first); + } + } + } + + virtual void traverse(Model& model, boost::any const& data) override { + + // First we need to apply functions called in function bodies + FunctionEliminationExpressionVisitor globalFunctionEliminationVisitor(&model.getGlobalFunctionDefinitions()); + eliminateFunctionsInFunctionBodies(globalFunctionEliminationVisitor, model.getGlobalFunctionDefinitions()); + + // Now run through the remaining components + for (auto& c : model.getConstants()) { + traverse(c, &globalFunctionEliminationVisitor); + } + JaniTraverser::traverse(model.getGlobalVariables(), &globalFunctionEliminationVisitor); + for (auto& aut : model.getAutomata()) { + traverse(aut, &globalFunctionEliminationVisitor); + } + if (model.hasInitialStatesRestriction()) { + model.setInitialStatesRestriction(globalFunctionEliminationVisitor.eliminate(model.getInitialStatesRestriction())); + } + } + + void traverse(Automaton& automaton, boost::any const& data) override { + // First we need to apply functions called in function bodies + auto functionEliminationVisitor = boost::any_cast(data)->setLocalFunctions(&automaton.getFunctionDefinitions()); + eliminateFunctionsInFunctionBodies(functionEliminationVisitor, automaton.getFunctionDefinitions()); + + // Now run through the remaining components + JaniTraverser::traverse(automaton.getVariables(), &functionEliminationVisitor); + for (auto& loc : automaton.getLocations()) { + JaniTraverser::traverse(loc, &functionEliminationVisitor); + } + JaniTraverser::traverse(automaton.getEdgeContainer(), &functionEliminationVisitor); + if (automaton.hasInitialStatesRestriction()) { + automaton.setInitialStatesRestriction(functionEliminationVisitor.eliminate(automaton.getInitialStatesRestriction())); + } + } + + void traverse(Constant& constant, boost::any const& data) override { + if (constant.isDefined()) { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + constant.define(functionEliminationVisitor->eliminate(constant.getExpression())); + } + } + + void traverse(BooleanVariable& variable, boost::any const& data) override { + if (variable.hasInitExpression()) { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression())); + } + } + + void traverse(BoundedIntegerVariable& variable, boost::any const& data) override { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + if (variable.hasInitExpression()) { + variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression())); + } + variable.setLowerBound(functionEliminationVisitor->eliminate(variable.getLowerBound())); + variable.setUpperBound(functionEliminationVisitor->eliminate(variable.getUpperBound())); + } + + void traverse(UnboundedIntegerVariable& variable, boost::any const& data) override { + if (variable.hasInitExpression()) { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression())); + } + } + + void traverse(RealVariable& variable, boost::any const& data) override { + if (variable.hasInitExpression()) { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression())); + } + } + + void traverse(ArrayVariable& variable, boost::any const& data) override { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + if (variable.hasInitExpression()) { + variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression())); + } + if (variable.hasElementTypeBounds()) { + variable.setElementTypeBounds(functionEliminationVisitor->eliminate(variable.getElementTypeBounds().first), functionEliminationVisitor->eliminate(variable.getElementTypeBounds().second)); + } + } + + void traverse(TemplateEdge& templateEdge, boost::any const& data) override { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + templateEdge.setGuard(functionEliminationVisitor->eliminate(templateEdge.getGuard())); + for (auto& dest : templateEdge.getDestinations()) { + JaniTraverser::traverse(dest, data); + } + JaniTraverser::traverse(templateEdge.getAssignments(), data); + } + + void traverse(Edge& edge, boost::any const& data) override { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + if (edge.hasRate()) { + edge.setRate(functionEliminationVisitor->eliminate(edge.getRate())); + } + for (auto& dest : edge.getDestinations()) { + traverse(dest, data); + } + } + + void traverse(EdgeDestination& edgeDestination, boost::any const& data) override { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + edgeDestination.setProbability(functionEliminationVisitor->eliminate(edgeDestination.getProbability())); + } + + void traverse(Assignment& assignment, boost::any const& data) override { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + assignment.setAssignedExpression(functionEliminationVisitor->eliminate(assignment.getAssignedExpression())); + traverse(assignment.getLValue(), data); + } + + void traverse(LValue& lValue, boost::any const& data) override { + if (lValue.isArrayAccess()) { + FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast(data); + lValue.setArrayIndex(functionEliminationVisitor->eliminate(lValue.getArrayIndex())); + } + } + + void traverse(storm::expressions::Expression const& expression, boost::any const& data) override { + STORM_LOG_THROW(getOccurringFunctionCalls(expression).empty(), storm::exceptions::UnexpectedException, "Did not translate functions in expression " << expression); + } + }; + } // namespace detail + + + void eliminateFunctions(Model& model, std::vector& properties) { + detail::FunctionEliminatorTraverser().eliminate(model, properties); + STORM_LOG_ASSERT(!containsFunctionCallExpression(model), "The model still seems to contain function calls."); + } + + } +} + diff --git a/src/storm/storage/jani/FunctionEliminator.h b/src/storm/storage/jani/FunctionEliminator.h new file mode 100644 index 000000000..791884aed --- /dev/null +++ b/src/storm/storage/jani/FunctionEliminator.h @@ -0,0 +1,19 @@ +#pragma once + + +#include + + +namespace storm { + namespace jani { + class Model; + class Property; + + /*! + * Eliminates all function references in the given model and the given properties by replacing them with their corresponding definitions. + */ + void eliminateFunctions(Model& model, std::vector& properties); + + } +} + diff --git a/src/storm/storage/jani/LValue.cpp b/src/storm/storage/jani/LValue.cpp index 7ae2b3e42..57d80bfdd 100644 --- a/src/storm/storage/jani/LValue.cpp +++ b/src/storm/storage/jani/LValue.cpp @@ -35,10 +35,15 @@ namespace storm { } storm::expressions::Expression const& LValue::getArrayIndex() const { - STORM_LOG_ASSERT(isArrayAccess(), "Tried to get the array index of an LValue, that is not an array access."); + STORM_LOG_ASSERT(isArrayAccess(), "Tried to get the array index of an LValue that is not an array access."); return arrayIndex; } + void LValue::setArrayIndex(storm::expressions::Expression const& newIndex) { + STORM_LOG_ASSERT(isArrayAccess(), "Tried to set the array index of an LValue that is not an array access."); + arrayIndex = newIndex; + } + bool LValue::isTransient() const { return variable->isTransient(); } diff --git a/src/storm/storage/jani/LValue.h b/src/storm/storage/jani/LValue.h index 0e8126bcd..1c6fa32c7 100644 --- a/src/storm/storage/jani/LValue.h +++ b/src/storm/storage/jani/LValue.h @@ -20,6 +20,7 @@ namespace storm { bool isArrayAccess() const; storm::jani::ArrayVariable const& getArray() const; storm::expressions::Expression const& getArrayIndex() const; + void setArrayIndex(storm::expressions::Expression const& newIndex); bool isTransient() const; bool operator< (LValue const& other) const; diff --git a/src/storm/storage/jani/Model.cpp b/src/storm/storage/jani/Model.cpp index e983cccab..5b3464ed4 100644 --- a/src/storm/storage/jani/Model.cpp +++ b/src/storm/storage/jani/Model.cpp @@ -17,6 +17,7 @@ #include "storm/storage/jani/Compositions.h" #include "storm/storage/jani/JSONExporter.h" #include "storm/storage/jani/ArrayEliminator.h" +#include "storm/storage/jani/FunctionEliminator.h" #include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h" #include "storm/storage/expressions/LinearityCheckVisitor.h" @@ -736,7 +737,7 @@ namespace storm { return globalFunctions; } - std::unordered_map Model::getGlobalFunctionDefinitions() { + std::unordered_map& Model::getGlobalFunctionDefinitions() { return globalFunctions; } @@ -927,6 +928,10 @@ namespace storm { } } + for (auto& functionDefinition : result.getGlobalFunctionDefinitions()) { + functionDefinition.second.substitute(constantSubstitution); + } + // Substitute constants in all global variables. for (auto& variable : result.getGlobalVariables().getBoundedIntegerVariables()) { variable.substitute(constantSubstitution); @@ -958,6 +963,15 @@ namespace storm { return result; } + void Model::substituteFunctions() { + std::vector emptyPropertyVector; + substituteFunctions(emptyPropertyVector); + } + + void Model::substituteFunctions(std::vector& properties) { + eliminateFunctions(*this, properties); + } + bool Model::containsArrayVariables() const { if (getGlobalVariables().containsArrayVariables()) { return true; diff --git a/src/storm/storage/jani/Model.h b/src/storm/storage/jani/Model.h index 2f0bbaa72..24faf8d76 100644 --- a/src/storm/storage/jani/Model.h +++ b/src/storm/storage/jani/Model.h @@ -262,7 +262,7 @@ namespace storm { /*! * Retrieves all global function definitions */ - std::unordered_map getGlobalFunctionDefinitions(); + std::unordered_map& getGlobalFunctionDefinitions(); /*! * Retrieves the manager responsible for the expressions in the JANI model. @@ -386,6 +386,13 @@ namespace storm { */ std::map getConstantsSubstitution() const; + /*! + * Substitutes all function calls with the corresponding function definition + * @param properties also eliminates function call expressions in the given properties + */ + void substituteFunctions(); + void substituteFunctions(std::vector& properties); + /*! * Returns true if at least one array variable occurs in the model. */ @@ -396,7 +403,7 @@ namespace storm { * @param keepNonTrivialArrayAccess if set, array access expressions in LValues and expressions are only replaced, if the index expression is constant. * @return data from the elimination. If non-trivial array accesses are kept, pointers to remaining array variables point to this data. */ - ArrayEliminatorData eliminateArrays(bool keepNonTrivialArrayAccess); + ArrayEliminatorData eliminateArrays(bool keepNonTrivialArrayAccess = false); /*! * Eliminates occurring array variables and expressions by replacing array variables by multiple basic variables. diff --git a/src/storm/storage/jani/expressions/FunctionCallExpression.cpp b/src/storm/storage/jani/expressions/FunctionCallExpression.cpp index 4ab6d64b4..783811461 100644 --- a/src/storm/storage/jani/expressions/FunctionCallExpression.cpp +++ b/src/storm/storage/jani/expressions/FunctionCallExpression.cpp @@ -71,5 +71,10 @@ namespace storm { STORM_LOG_THROW(i < arguments.size(), storm::exceptions::InvalidArgumentException, "Tried to access the argument with index " << i << " of a function call with " << arguments.size() << " arguments."); return arguments[i]; } + + std::vector> const& FunctionCallExpression::getArguments() const { + return arguments; + } + } } \ No newline at end of file diff --git a/src/storm/storage/jani/expressions/FunctionCallExpression.h b/src/storm/storage/jani/expressions/FunctionCallExpression.h index 75949247a..66ab8dacd 100644 --- a/src/storm/storage/jani/expressions/FunctionCallExpression.h +++ b/src/storm/storage/jani/expressions/FunctionCallExpression.h @@ -29,6 +29,7 @@ namespace storm { std::string const& getFunctionIdentifier() const; uint64_t getNumberOfArguments() const; std::shared_ptr getArgument(uint64_t i) const; + std::vector> const& getArguments() const; protected: diff --git a/src/storm/storage/jani/traverser/FunctionCallExpressionFinder.cpp b/src/storm/storage/jani/traverser/FunctionCallExpressionFinder.cpp new file mode 100644 index 000000000..f9f8b22ac --- /dev/null +++ b/src/storm/storage/jani/traverser/FunctionCallExpressionFinder.cpp @@ -0,0 +1,124 @@ +#include "storm/storage/jani/traverser/FunctionCallExpressionFinder.h" + +#include "storm/storage/jani/traverser/JaniTraverser.h" +#include "storm/storage/jani/expressions/JaniExpressionVisitor.h" +#include "storm/storage/jani/expressions/JaniExpressions.h" +#include "storm/storage/jani/Model.h" + +namespace storm { + namespace jani { + + namespace detail { + class FunctionCallExpressionFinderExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor { + public: + virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override { + expression.getCondition()->accept(*this, data); + expression.getThenExpression()->accept(*this, data); + expression.getElseExpression()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override { + expression.getFirstOperand()->accept(*this, data); + expression.getSecondOperand()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override { + expression.getFirstOperand()->accept(*this, data); + expression.getSecondOperand()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override { + expression.getFirstOperand()->accept(*this, data); + expression.getSecondOperand()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::VariableExpression const&, boost::any const&) override { + return boost::any(); + } + + virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override { + expression.getOperand()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override { + expression.getOperand()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::BooleanLiteralExpression const&, boost::any const&) override { + return boost::any(); + } + + virtual boost::any visit(storm::expressions::IntegerLiteralExpression const&, boost::any const&) override { + return boost::any(); + } + + virtual boost::any visit(storm::expressions::RationalLiteralExpression const&, boost::any const&) override { + return boost::any(); + } + + virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override { + STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ")."); + uint64_t size = expression.size()->evaluateAsInt(); + for (uint64_t i = 0; i < size; ++i) { + expression.at(i)->accept(*this, data); + } + return boost::any(); + } + + virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override { + expression.getElementExpression()->accept(*this, data); + expression.size()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override { + expression.getFirstOperand()->accept(*this, data); + expression.getSecondOperand()->accept(*this, data); + return boost::any(); + } + + virtual boost::any visit(storm::expressions::FunctionCallExpression const& expression, boost::any const& data) override { + auto& set = *boost::any_cast*>(data); + set.insert(expression.getIdentifier()); + for (uint64_t i = 0; i < expression.getNumberOfArguments(); ++i) { + expression.getArgument(i)->accept(*this, data); + } + return boost::any(); + } + }; + + class FunctionCallExpressionFinderTraverser : public ConstJaniTraverser { + public: + virtual void traverse(Model const& model, boost::any const& data) override { + ConstJaniTraverser::traverse(model, data); + } + + virtual void traverse(storm::expressions::Expression const& expression, boost::any const& data) override { + auto& res = *boost::any_cast(data); + res = res || !getOccurringFunctionCalls(expression).empty(); + } + }; + } + + + bool containsFunctionCallExpression(Model const& model) { + bool result = false; + detail::FunctionCallExpressionFinderTraverser().traverse(model, &result); + return result; + } + + std::unordered_set getOccurringFunctionCalls(storm::expressions::Expression const& expression) { + detail::FunctionCallExpressionFinderExpressionVisitor v; + std::unordered_set result; + expression.accept(v, &result); + return result; + } + } +} + diff --git a/src/storm/storage/jani/traverser/FunctionCallExpressionFinder.h b/src/storm/storage/jani/traverser/FunctionCallExpressionFinder.h new file mode 100644 index 000000000..e191eee4a --- /dev/null +++ b/src/storm/storage/jani/traverser/FunctionCallExpressionFinder.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +namespace storm { + + namespace expressions { + class Expression; + } + + namespace jani { + + class Model; + + bool containsFunctionCallExpression(Model const& model); + std::unordered_set getOccurringFunctionCalls(storm::expressions::Expression const& expr); + } +} + diff --git a/src/storm/storage/jani/traverser/JaniTraverser.cpp b/src/storm/storage/jani/traverser/JaniTraverser.cpp index 787682ede..ae912a3a8 100644 --- a/src/storm/storage/jani/traverser/JaniTraverser.cpp +++ b/src/storm/storage/jani/traverser/JaniTraverser.cpp @@ -11,6 +11,9 @@ namespace storm { for (auto& c : model.getConstants()) { traverse(c, data); } + for (auto& f : model.getGlobalFunctionDefinitions()) { + traverse(f.second, data); + } traverse(model.getGlobalVariables(), data); for (auto& aut : model.getAutomata()) { traverse(aut, data); @@ -26,6 +29,9 @@ namespace storm { void JaniTraverser::traverse(Automaton& automaton, boost::any const& data) { traverse(automaton.getVariables(), data); + for (auto& f : automaton.getFunctionDefinitions()) { + traverse(f.second, data); + } for (auto& loc : automaton.getLocations()) { traverse(loc, data); } @@ -41,6 +47,10 @@ namespace storm { } } + void JaniTraverser::traverse(FunctionDefinition& functionDefinition, boost::any const& data) { + traverse(functionDefinition.getFunctionBody(), data); + } + void JaniTraverser::traverse(VariableSet& variableSet, boost::any const& data) { for (auto& v : variableSet.getBooleanVariables()) { traverse(v, data); @@ -162,6 +172,9 @@ namespace storm { for (auto const& c : model.getConstants()) { traverse(c, data); } + for (auto const& f : model.getGlobalFunctionDefinitions()) { + traverse(f.second, data); + } traverse(model.getGlobalVariables(), data); for (auto const& aut : model.getAutomata()) { traverse(aut, data); @@ -177,6 +190,9 @@ namespace storm { void ConstJaniTraverser::traverse(Automaton const& automaton, boost::any const& data) { traverse(automaton.getVariables(), data); + for (auto const& f : automaton.getFunctionDefinitions()) { + traverse(f.second, data); + } for (auto const& loc : automaton.getLocations()) { traverse(loc, data); } @@ -192,6 +208,10 @@ namespace storm { } } + void ConstJaniTraverser::traverse(FunctionDefinition const& functionDefinition, boost::any const& data) { + traverse(functionDefinition.getFunctionBody(), data); + } + void ConstJaniTraverser::traverse(VariableSet const& variableSet, boost::any const& data) { for (auto const& v : variableSet.getBooleanVariables()) { traverse(v, data); diff --git a/src/storm/storage/jani/traverser/JaniTraverser.h b/src/storm/storage/jani/traverser/JaniTraverser.h index 685c76a85..a7da8bc6f 100644 --- a/src/storm/storage/jani/traverser/JaniTraverser.h +++ b/src/storm/storage/jani/traverser/JaniTraverser.h @@ -16,6 +16,7 @@ namespace storm { virtual void traverse(Action const& action, boost::any const& data); virtual void traverse(Automaton& automaton, boost::any const& data); virtual void traverse(Constant& constant, boost::any const& data); + virtual void traverse(FunctionDefinition& functionDefinition, boost::any const& data); virtual void traverse(VariableSet& variableSet, boost::any const& data); virtual void traverse(Location& location, boost::any const& data); virtual void traverse(BooleanVariable& variable, boost::any const& data); @@ -43,6 +44,7 @@ namespace storm { virtual void traverse(Action const& action, boost::any const& data); virtual void traverse(Automaton const& automaton, boost::any const& data); virtual void traverse(Constant const& constant, boost::any const& data); + virtual void traverse(FunctionDefinition const& functionDefinition, boost::any const& data); virtual void traverse(VariableSet const& variableSet, boost::any const& data); virtual void traverse(Location const& location, boost::any const& data); virtual void traverse(BooleanVariable const& variable, boost::any const& data);