Browse Source

elimination of jani function

tempestpy_adaptions
TimQu 6 years ago
parent
commit
c739f0befa
  1. 18
      src/storm-cli-utilities/model-handling.h
  2. 6
      src/storm-conv/api/storm-conv.cpp
  3. 5
      src/storm/storage/jani/Automaton.cpp
  4. 2
      src/storm/storage/jani/Automaton.h
  5. 19
      src/storm/storage/jani/FunctionDefinition.cpp
  6. 8
      src/storm/storage/jani/FunctionDefinition.h
  7. 403
      src/storm/storage/jani/FunctionEliminator.cpp
  8. 19
      src/storm/storage/jani/FunctionEliminator.h
  9. 7
      src/storm/storage/jani/LValue.cpp
  10. 1
      src/storm/storage/jani/LValue.h
  11. 16
      src/storm/storage/jani/Model.cpp
  12. 11
      src/storm/storage/jani/Model.h
  13. 5
      src/storm/storage/jani/expressions/FunctionCallExpression.cpp
  14. 1
      src/storm/storage/jani/expressions/FunctionCallExpression.h
  15. 124
      src/storm/storage/jani/traverser/FunctionCallExpressionFinder.cpp
  16. 20
      src/storm/storage/jani/traverser/FunctionCallExpressionFinder.h
  17. 20
      src/storm/storage/jani/traverser/JaniTraverser.cpp
  18. 2
      src/storm/storage/jani/traverser/JaniTraverser.h

18
src/storm-cli-utilities/model-handling.h

@ -154,11 +154,23 @@ namespace storm {
} }
} }
// Check whether transformations on the jani model are required
if (output.model && output.model.get().isJaniModel()) { if (output.model && output.model.get().isJaniModel()) {
// Check if arrays need to be eliminated
if (coreSettings.getEngine() != storm::settings::modules::CoreSettings::Engine::Sparse || buildSettings.isJitSet()) {
auto& janiModel = output.model.get().asJaniModel();
// Check if functions need to be eliminated
if (janiModel.getModelFeatures().hasFunctions()) {
if (!output.preprocessedProperties) {
output.preprocessedProperties = output.properties; output.preprocessedProperties = output.properties;
output.model.get().asJaniModel().eliminateArrays(output.preprocessedProperties.get());
}
janiModel.substituteFunctions(output.preprocessedProperties.get());
}
// Check if arrays need to be eliminated. This should be done after! eliminating the functions
if (janiModel.getModelFeatures().hasArrays() && (coreSettings.getEngine() != storm::settings::modules::CoreSettings::Engine::Sparse || buildSettings.isJitSet())) {
if (!output.preprocessedProperties) {
output.preprocessedProperties = output.properties;
}
janiModel.eliminateArrays(output.preprocessedProperties.get());
} }
} }
return output; return output;

6
src/storm-conv/api/storm-conv.cpp

@ -39,9 +39,9 @@ namespace storm {
janiModel.eliminateArrays(properties); janiModel.eliminateArrays(properties);
} }
//if (!options.allowFunctions && janiModel.getModelFeatures().hasFunctions()) {
//janiModel = janiModel.substituteFunctions();
//}
if (!options.allowFunctions && janiModel.getModelFeatures().hasFunctions()) {
janiModel = janiModel.substituteFunctions(properties);
}
if (options.modelName) { if (options.modelName) {
janiModel.setName(options.modelName.get()); janiModel.setName(options.modelName.get());

5
src/storm/storage/jani/Automaton.cpp

@ -95,7 +95,7 @@ namespace storm {
return functionDefinitions; return functionDefinitions;
} }
std::unordered_map<std::string, FunctionDefinition> Automaton::getFunctionDefinitions() {
std::unordered_map<std::string, FunctionDefinition>& Automaton::getFunctionDefinitions() {
return functionDefinitions; return functionDefinitions;
} }
@ -396,6 +396,9 @@ namespace storm {
} }
void Automaton::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) { void Automaton::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) {
for (auto& functionDefinition : this->getFunctionDefinitions()) {
functionDefinition.second.substitute(substitution);
}
for (auto& variable : this->getVariables().getBoundedIntegerVariables()) { for (auto& variable : this->getVariables().getBoundedIntegerVariables()) {
variable.substitute(substitution); variable.substitute(substitution);
} }

2
src/storm/storage/jani/Automaton.h

@ -112,7 +112,7 @@ namespace storm {
/*! /*!
* Retrieves all function definitions of this automaton * Retrieves all function definitions of this automaton
*/ */
std::unordered_map<std::string, FunctionDefinition> getFunctionDefinitions();
std::unordered_map<std::string, FunctionDefinition>& getFunctionDefinitions();
/*! /*!
* Retrieves whether the automaton has a location with the given name. * Retrieves whether the automaton has a location with the given name.

19
src/storm/storage/jani/FunctionDefinition.cpp

@ -1,4 +1,9 @@
#include "storm/storage/jani/FunctionDefinition.h" #include "storm/storage/jani/FunctionDefinition.h"
#include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/InvalidArgumentException.h"
namespace storm { namespace storm {
namespace jani { namespace jani {
@ -23,6 +28,20 @@ namespace storm {
return functionBody; return functionBody;
} }
storm::expressions::Expression FunctionDefinition::call(std::vector<std::shared_ptr<storm::expressions::BaseExpression const>> const& arguments ) const {
// substitute the parameters in the function body
STORM_LOG_THROW(arguments.size() == parameters.size(), storm::exceptions::InvalidArgumentException, "The number of arguments does not match the number of parameters.");
std::unordered_map<storm::expressions::Variable, storm::expressions::Expression> parameterSubstitution;
for (uint64_t i = 0; i < arguments.size(); ++i) {
parameterSubstitution.emplace(parameters[i], arguments[i]);
}
return substituteJaniExpression(functionBody, parameterSubstitution);
}
void FunctionDefinition::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) {
this->setFunctionBody(substituteJaniExpression(this->getFunctionBody(), substitution));
}
void FunctionDefinition::setFunctionBody(storm::expressions::Expression const& body) { void FunctionDefinition::setFunctionBody(storm::expressions::Expression const& body) {
functionBody = body; functionBody = body;
} }

8
src/storm/storage/jani/FunctionDefinition.h

@ -43,6 +43,14 @@ namespace storm {
*/ */
void setFunctionBody(storm::expressions::Expression const& body); void setFunctionBody(storm::expressions::Expression const& body);
/*!
* Calls the function with the given arguments
*/
storm::expressions::Expression call(std::vector<std::shared_ptr<storm::expressions::BaseExpression const>> const& arguments ) const;
void substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution);
private: private:
// The name of the function. // The name of the function.
std::string name; std::string name;

403
src/storm/storage/jani/FunctionEliminator.cpp

@ -0,0 +1,403 @@
#include "storm/storage/jani/FunctionEliminator.h"
#include <unordered_map>
#include "storm/storage/expressions/ExpressionVisitor.h"
#include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
#include "storm/storage/jani/expressions/JaniExpressions.h"
#include "storm/storage/jani/Variable.h"
#include "storm/storage/jani/traverser/JaniTraverser.h"
#include "storm/storage/jani/traverser/FunctionCallExpressionFinder.h"
#include "storm/storage/jani/Model.h"
#include "storm/storage/jani/Property.h"
#include "storm/storage/expressions/Expressions.h"
#include "storm/storage/expressions/ExpressionManager.h"
#include "storm/exceptions/UnexpectedException.h"
#include "storm/exceptions/NotSupportedException.h"
namespace storm {
namespace jani {
namespace detail {
class FunctionEliminationExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
public:
typedef std::shared_ptr<storm::expressions::BaseExpression const> BaseExprPtr;
FunctionEliminationExpressionVisitor(std::unordered_map<std::string, FunctionDefinition> const* globalFunctions, std::unordered_map<std::string, FunctionDefinition> const* localFunctions = nullptr) : globalFunctions(globalFunctions), localFunctions(localFunctions) {}
virtual ~FunctionEliminationExpressionVisitor() = default;
FunctionEliminationExpressionVisitor setLocalFunctions(std::unordered_map<std::string, FunctionDefinition> const* localFunctions) {
return FunctionEliminationExpressionVisitor(this->globalFunctions, localFunctions);
}
storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) {
auto res = storm::expressions::Expression(boost::any_cast<BaseExprPtr>(expression.accept(*this, boost::any())));
return res.simplify();
}
virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
BaseExprPtr conditionExpression = boost::any_cast<BaseExprPtr>(expression.getCondition()->accept(*this, data));
BaseExprPtr thenExpression = boost::any_cast<BaseExprPtr>(expression.getThenExpression()->accept(*this, data));
BaseExprPtr elseExpression = boost::any_cast<BaseExprPtr>(expression.getElseExpression()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::IfThenElseExpression(expression.getManager(), thenExpression->getType(), conditionExpression, thenExpression, elseExpression)));
}
}
virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType())));
}
}
virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType())));
}
}
virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getRelationType())));
}
}
virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const& data) override {
return expression.getSharedPointer();
}
virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
BaseExprPtr operandExpression = boost::any_cast<BaseExprPtr>(expression.getOperand()->accept(*this, data));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType())));
}
}
virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
BaseExprPtr operandExpression = boost::any_cast<BaseExprPtr>(expression.getOperand()->accept(*this, data));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType())));
}
}
virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override {
return expression.getSharedPointer();
}
virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override {
return expression.getSharedPointer();
}
virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override {
return expression.getSharedPointer();
}
virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
uint64_t size = expression.size()->evaluateAsInt();
std::vector<BaseExprPtr> elements;
bool changed = false;
for (uint64_t i = 0; i<size; ++i) {
BaseExprPtr element = boost::any_cast<BaseExprPtr>(expression.at(i)->accept(*this, data));
if (element.get() != expression.at(i).get()) {
changed = true;
}
elements.push_back(element);
}
if (changed) {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::ValueArrayExpression(expression.getManager(), expression.getType(), elements)));
} else {
return expression.getSharedPointer();
}
}
virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
BaseExprPtr sizeExpression = boost::any_cast<BaseExprPtr>(expression.size()->accept(*this, data));
BaseExprPtr elementExpression = boost::any_cast<BaseExprPtr>(expression.getElementExpression()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (sizeExpression.get() == expression.size().get() && elementExpression.get() == expression.getElementExpression().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::ConstructorArrayExpression(expression.getManager(), expression.getType(), sizeExpression, expression.getIndexVar(), elementExpression)));
}
}
virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override {
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::ArrayAccessExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression)));
}
}
virtual boost::any visit(storm::expressions::FunctionCallExpression const& expression, boost::any const& data) override {
// Find the associated function definition
FunctionDefinition const* funDef = nullptr;
if (localFunctions != nullptr) {
auto funDefIt = localFunctions->find(expression.getIdentifier());
if (funDefIt != localFunctions->end()) {
funDef = &(funDefIt->second);
}
}
if (globalFunctions != nullptr) {
auto funDefIt = globalFunctions->find(expression.getIdentifier());
if (funDefIt != globalFunctions->end()) {
funDef = &(funDefIt->second);
}
}
STORM_LOG_THROW(funDef != nullptr, storm::exceptions::UnexpectedException, "Unable to find function definition for function call " << expression << ".");
return boost::any_cast<BaseExprPtr>(funDef->call(expression.getArguments()).getBaseExpression().accept(*this, data));
}
private:
std::unordered_map<std::string, FunctionDefinition> const* globalFunctions;
std::unordered_map<std::string, FunctionDefinition> const* localFunctions;
};
class FunctionEliminatorTraverser : public JaniTraverser {
public:
FunctionEliminatorTraverser() = default;
virtual ~FunctionEliminatorTraverser() = default;
void eliminate(Model& model, std::vector<Property>& properties) {
// Replace all function calls by the function definition
traverse(model, boost::any());
// Replace function definitions in properties
if (!model.getGlobalFunctionDefinitions().empty()) {
FunctionEliminationExpressionVisitor v(&model.getGlobalFunctionDefinitions());
for (auto& property : properties) {
property = property.substitute([&v](storm::expressions::Expression const& exp) {return v.eliminate(exp);});
}
}
// Erase function definitions in model and automata
model.getGlobalFunctionDefinitions().clear();
for (auto& automaton : model.getAutomata()) {
automaton.getFunctionDefinitions().clear();
}
// Clear the model feature 'functions'
model.getModelFeatures().remove(ModelFeature::Functions);
// finalize the model
model.finalize();
}
// To detect cyclic dependencies between function bodies, we need to eliminate the functions in a topological order
enum class FunctionDefinitionStatus {Unprocessed, Current, Processed};
void eliminateFunctionsInFunctionBodies(FunctionEliminationExpressionVisitor& eliminationVisitor, std::unordered_map<std::string, FunctionDefinition>& functions, std::unordered_map<std::string, FunctionDefinitionStatus>& status, std::string const& current) {
status[current] = FunctionDefinitionStatus::Current;
FunctionDefinition& funDef = functions.find(current)->second;
auto calledFunctions = getOccurringFunctionCalls(funDef.getFunctionBody());
for (auto const& calledFunction : calledFunctions) {
STORM_LOG_THROW(calledFunction != current, storm::exceptions::NotSupportedException, "Function '" << calledFunction << "' calls itself. This is not supported.");
auto calledStatus = status.find(calledFunction);
// Check whether the called function belongs to the ones that actually needed processing
if (calledStatus != status.end()) {
STORM_LOG_THROW(calledStatus->second != FunctionDefinitionStatus::Current, storm::exceptions::NotSupportedException, "Found cyclic dependencies between functions '" << calledFunction << "' and '" << current << "'. This is not supported.");
if (calledStatus->second == FunctionDefinitionStatus::Unprocessed) {
eliminateFunctionsInFunctionBodies(eliminationVisitor, functions, status, calledFunction);
}
}
}
// At this point, all called functions are processed already. So we can finally process this one.
funDef.setFunctionBody(eliminationVisitor.eliminate(funDef.getFunctionBody()));
status[current] = FunctionDefinitionStatus::Processed;
}
void eliminateFunctionsInFunctionBodies(FunctionEliminationExpressionVisitor& eliminationVisitor, std::unordered_map<std::string, FunctionDefinition>& functions) {
std::unordered_map<std::string, FunctionDefinitionStatus> status;
for (auto const& f : functions) {
status.emplace(f.first, FunctionDefinitionStatus::Unprocessed);
}
for (auto const& f : functions) {
if (status[f.first] == FunctionDefinitionStatus::Unprocessed) {
eliminateFunctionsInFunctionBodies(eliminationVisitor, functions, status, f.first);
}
}
}
virtual void traverse(Model& model, boost::any const& data) override {
// First we need to apply functions called in function bodies
FunctionEliminationExpressionVisitor globalFunctionEliminationVisitor(&model.getGlobalFunctionDefinitions());
eliminateFunctionsInFunctionBodies(globalFunctionEliminationVisitor, model.getGlobalFunctionDefinitions());
// Now run through the remaining components
for (auto& c : model.getConstants()) {
traverse(c, &globalFunctionEliminationVisitor);
}
JaniTraverser::traverse(model.getGlobalVariables(), &globalFunctionEliminationVisitor);
for (auto& aut : model.getAutomata()) {
traverse(aut, &globalFunctionEliminationVisitor);
}
if (model.hasInitialStatesRestriction()) {
model.setInitialStatesRestriction(globalFunctionEliminationVisitor.eliminate(model.getInitialStatesRestriction()));
}
}
void traverse(Automaton& automaton, boost::any const& data) override {
// First we need to apply functions called in function bodies
auto functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data)->setLocalFunctions(&automaton.getFunctionDefinitions());
eliminateFunctionsInFunctionBodies(functionEliminationVisitor, automaton.getFunctionDefinitions());
// Now run through the remaining components
JaniTraverser::traverse(automaton.getVariables(), &functionEliminationVisitor);
for (auto& loc : automaton.getLocations()) {
JaniTraverser::traverse(loc, &functionEliminationVisitor);
}
JaniTraverser::traverse(automaton.getEdgeContainer(), &functionEliminationVisitor);
if (automaton.hasInitialStatesRestriction()) {
automaton.setInitialStatesRestriction(functionEliminationVisitor.eliminate(automaton.getInitialStatesRestriction()));
}
}
void traverse(Constant& constant, boost::any const& data) override {
if (constant.isDefined()) {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
constant.define(functionEliminationVisitor->eliminate(constant.getExpression()));
}
}
void traverse(BooleanVariable& variable, boost::any const& data) override {
if (variable.hasInitExpression()) {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression()));
}
}
void traverse(BoundedIntegerVariable& variable, boost::any const& data) override {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
if (variable.hasInitExpression()) {
variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression()));
}
variable.setLowerBound(functionEliminationVisitor->eliminate(variable.getLowerBound()));
variable.setUpperBound(functionEliminationVisitor->eliminate(variable.getUpperBound()));
}
void traverse(UnboundedIntegerVariable& variable, boost::any const& data) override {
if (variable.hasInitExpression()) {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression()));
}
}
void traverse(RealVariable& variable, boost::any const& data) override {
if (variable.hasInitExpression()) {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression()));
}
}
void traverse(ArrayVariable& variable, boost::any const& data) override {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
if (variable.hasInitExpression()) {
variable.setInitExpression(functionEliminationVisitor->eliminate(variable.getInitExpression()));
}
if (variable.hasElementTypeBounds()) {
variable.setElementTypeBounds(functionEliminationVisitor->eliminate(variable.getElementTypeBounds().first), functionEliminationVisitor->eliminate(variable.getElementTypeBounds().second));
}
}
void traverse(TemplateEdge& templateEdge, boost::any const& data) override {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
templateEdge.setGuard(functionEliminationVisitor->eliminate(templateEdge.getGuard()));
for (auto& dest : templateEdge.getDestinations()) {
JaniTraverser::traverse(dest, data);
}
JaniTraverser::traverse(templateEdge.getAssignments(), data);
}
void traverse(Edge& edge, boost::any const& data) override {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
if (edge.hasRate()) {
edge.setRate(functionEliminationVisitor->eliminate(edge.getRate()));
}
for (auto& dest : edge.getDestinations()) {
traverse(dest, data);
}
}
void traverse(EdgeDestination& edgeDestination, boost::any const& data) override {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
edgeDestination.setProbability(functionEliminationVisitor->eliminate(edgeDestination.getProbability()));
}
void traverse(Assignment& assignment, boost::any const& data) override {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
assignment.setAssignedExpression(functionEliminationVisitor->eliminate(assignment.getAssignedExpression()));
traverse(assignment.getLValue(), data);
}
void traverse(LValue& lValue, boost::any const& data) override {
if (lValue.isArrayAccess()) {
FunctionEliminationExpressionVisitor* functionEliminationVisitor = boost::any_cast<FunctionEliminationExpressionVisitor*>(data);
lValue.setArrayIndex(functionEliminationVisitor->eliminate(lValue.getArrayIndex()));
}
}
void traverse(storm::expressions::Expression const& expression, boost::any const& data) override {
STORM_LOG_THROW(getOccurringFunctionCalls(expression).empty(), storm::exceptions::UnexpectedException, "Did not translate functions in expression " << expression);
}
};
} // namespace detail
void eliminateFunctions(Model& model, std::vector<Property>& properties) {
detail::FunctionEliminatorTraverser().eliminate(model, properties);
STORM_LOG_ASSERT(!containsFunctionCallExpression(model), "The model still seems to contain function calls.");
}
}
}

19
src/storm/storage/jani/FunctionEliminator.h

@ -0,0 +1,19 @@
#pragma once
#include <vector>
namespace storm {
namespace jani {
class Model;
class Property;
/*!
* Eliminates all function references in the given model and the given properties by replacing them with their corresponding definitions.
*/
void eliminateFunctions(Model& model, std::vector<Property>& properties);
}
}

7
src/storm/storage/jani/LValue.cpp

@ -35,10 +35,15 @@ namespace storm {
} }
storm::expressions::Expression const& LValue::getArrayIndex() const { storm::expressions::Expression const& LValue::getArrayIndex() const {
STORM_LOG_ASSERT(isArrayAccess(), "Tried to get the array index of an LValue, that is not an array access.");
STORM_LOG_ASSERT(isArrayAccess(), "Tried to get the array index of an LValue that is not an array access.");
return arrayIndex; return arrayIndex;
} }
void LValue::setArrayIndex(storm::expressions::Expression const& newIndex) {
STORM_LOG_ASSERT(isArrayAccess(), "Tried to set the array index of an LValue that is not an array access.");
arrayIndex = newIndex;
}
bool LValue::isTransient() const { bool LValue::isTransient() const {
return variable->isTransient(); return variable->isTransient();
} }

1
src/storm/storage/jani/LValue.h

@ -20,6 +20,7 @@ namespace storm {
bool isArrayAccess() const; bool isArrayAccess() const;
storm::jani::ArrayVariable const& getArray() const; storm::jani::ArrayVariable const& getArray() const;
storm::expressions::Expression const& getArrayIndex() const; storm::expressions::Expression const& getArrayIndex() const;
void setArrayIndex(storm::expressions::Expression const& newIndex);
bool isTransient() const; bool isTransient() const;
bool operator< (LValue const& other) const; bool operator< (LValue const& other) const;

16
src/storm/storage/jani/Model.cpp

@ -17,6 +17,7 @@
#include "storm/storage/jani/Compositions.h" #include "storm/storage/jani/Compositions.h"
#include "storm/storage/jani/JSONExporter.h" #include "storm/storage/jani/JSONExporter.h"
#include "storm/storage/jani/ArrayEliminator.h" #include "storm/storage/jani/ArrayEliminator.h"
#include "storm/storage/jani/FunctionEliminator.h"
#include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h" #include "storm/storage/jani/expressions/JaniExpressionSubstitutionVisitor.h"
#include "storm/storage/expressions/LinearityCheckVisitor.h" #include "storm/storage/expressions/LinearityCheckVisitor.h"
@ -736,7 +737,7 @@ namespace storm {
return globalFunctions; return globalFunctions;
} }
std::unordered_map<std::string, FunctionDefinition> Model::getGlobalFunctionDefinitions() {
std::unordered_map<std::string, FunctionDefinition>& Model::getGlobalFunctionDefinitions() {
return globalFunctions; return globalFunctions;
} }
@ -927,6 +928,10 @@ namespace storm {
} }
} }
for (auto& functionDefinition : result.getGlobalFunctionDefinitions()) {
functionDefinition.second.substitute(constantSubstitution);
}
// Substitute constants in all global variables. // Substitute constants in all global variables.
for (auto& variable : result.getGlobalVariables().getBoundedIntegerVariables()) { for (auto& variable : result.getGlobalVariables().getBoundedIntegerVariables()) {
variable.substitute(constantSubstitution); variable.substitute(constantSubstitution);
@ -958,6 +963,15 @@ namespace storm {
return result; return result;
} }
void Model::substituteFunctions() {
std::vector<Property> emptyPropertyVector;
substituteFunctions(emptyPropertyVector);
}
void Model::substituteFunctions(std::vector<Property>& properties) {
eliminateFunctions(*this, properties);
}
bool Model::containsArrayVariables() const { bool Model::containsArrayVariables() const {
if (getGlobalVariables().containsArrayVariables()) { if (getGlobalVariables().containsArrayVariables()) {
return true; return true;

11
src/storm/storage/jani/Model.h

@ -262,7 +262,7 @@ namespace storm {
/*! /*!
* Retrieves all global function definitions * Retrieves all global function definitions
*/ */
std::unordered_map<std::string, FunctionDefinition> getGlobalFunctionDefinitions();
std::unordered_map<std::string, FunctionDefinition>& getGlobalFunctionDefinitions();
/*! /*!
* Retrieves the manager responsible for the expressions in the JANI model. * Retrieves the manager responsible for the expressions in the JANI model.
@ -386,6 +386,13 @@ namespace storm {
*/ */
std::map<storm::expressions::Variable, storm::expressions::Expression> getConstantsSubstitution() const; std::map<storm::expressions::Variable, storm::expressions::Expression> getConstantsSubstitution() const;
/*!
* Substitutes all function calls with the corresponding function definition
* @param properties also eliminates function call expressions in the given properties
*/
void substituteFunctions();
void substituteFunctions(std::vector<Property>& properties);
/*! /*!
* Returns true if at least one array variable occurs in the model. * Returns true if at least one array variable occurs in the model.
*/ */
@ -396,7 +403,7 @@ namespace storm {
* @param keepNonTrivialArrayAccess if set, array access expressions in LValues and expressions are only replaced, if the index expression is constant. * @param keepNonTrivialArrayAccess if set, array access expressions in LValues and expressions are only replaced, if the index expression is constant.
* @return data from the elimination. If non-trivial array accesses are kept, pointers to remaining array variables point to this data. * @return data from the elimination. If non-trivial array accesses are kept, pointers to remaining array variables point to this data.
*/ */
ArrayEliminatorData eliminateArrays(bool keepNonTrivialArrayAccess);
ArrayEliminatorData eliminateArrays(bool keepNonTrivialArrayAccess = false);
/*! /*!
* Eliminates occurring array variables and expressions by replacing array variables by multiple basic variables. * Eliminates occurring array variables and expressions by replacing array variables by multiple basic variables.

5
src/storm/storage/jani/expressions/FunctionCallExpression.cpp

@ -71,5 +71,10 @@ namespace storm {
STORM_LOG_THROW(i < arguments.size(), storm::exceptions::InvalidArgumentException, "Tried to access the argument with index " << i << " of a function call with " << arguments.size() << " arguments."); STORM_LOG_THROW(i < arguments.size(), storm::exceptions::InvalidArgumentException, "Tried to access the argument with index " << i << " of a function call with " << arguments.size() << " arguments.");
return arguments[i]; return arguments[i];
} }
std::vector<std::shared_ptr<BaseExpression const>> const& FunctionCallExpression::getArguments() const {
return arguments;
}
} }
} }

1
src/storm/storage/jani/expressions/FunctionCallExpression.h

@ -29,6 +29,7 @@ namespace storm {
std::string const& getFunctionIdentifier() const; std::string const& getFunctionIdentifier() const;
uint64_t getNumberOfArguments() const; uint64_t getNumberOfArguments() const;
std::shared_ptr<BaseExpression const> getArgument(uint64_t i) const; std::shared_ptr<BaseExpression const> getArgument(uint64_t i) const;
std::vector<std::shared_ptr<BaseExpression const>> const& getArguments() const;
protected: protected:

124
src/storm/storage/jani/traverser/FunctionCallExpressionFinder.cpp

@ -0,0 +1,124 @@
#include "storm/storage/jani/traverser/FunctionCallExpressionFinder.h"
#include "storm/storage/jani/traverser/JaniTraverser.h"
#include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
#include "storm/storage/jani/expressions/JaniExpressions.h"
#include "storm/storage/jani/Model.h"
namespace storm {
namespace jani {
namespace detail {
class FunctionCallExpressionFinderExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
public:
virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
expression.getCondition()->accept(*this, data);
expression.getThenExpression()->accept(*this, data);
expression.getElseExpression()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
expression.getFirstOperand()->accept(*this, data);
expression.getSecondOperand()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
expression.getFirstOperand()->accept(*this, data);
expression.getSecondOperand()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
expression.getFirstOperand()->accept(*this, data);
expression.getSecondOperand()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::VariableExpression const&, boost::any const&) override {
return boost::any();
}
virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
expression.getOperand()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
expression.getOperand()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::BooleanLiteralExpression const&, boost::any const&) override {
return boost::any();
}
virtual boost::any visit(storm::expressions::IntegerLiteralExpression const&, boost::any const&) override {
return boost::any();
}
virtual boost::any visit(storm::expressions::RationalLiteralExpression const&, boost::any const&) override {
return boost::any();
}
virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
uint64_t size = expression.size()->evaluateAsInt();
for (uint64_t i = 0; i < size; ++i) {
expression.at(i)->accept(*this, data);
}
return boost::any();
}
virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
expression.getElementExpression()->accept(*this, data);
expression.size()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override {
expression.getFirstOperand()->accept(*this, data);
expression.getSecondOperand()->accept(*this, data);
return boost::any();
}
virtual boost::any visit(storm::expressions::FunctionCallExpression const& expression, boost::any const& data) override {
auto& set = *boost::any_cast<std::unordered_set<std::string>*>(data);
set.insert(expression.getIdentifier());
for (uint64_t i = 0; i < expression.getNumberOfArguments(); ++i) {
expression.getArgument(i)->accept(*this, data);
}
return boost::any();
}
};
class FunctionCallExpressionFinderTraverser : public ConstJaniTraverser {
public:
virtual void traverse(Model const& model, boost::any const& data) override {
ConstJaniTraverser::traverse(model, data);
}
virtual void traverse(storm::expressions::Expression const& expression, boost::any const& data) override {
auto& res = *boost::any_cast<bool*>(data);
res = res || !getOccurringFunctionCalls(expression).empty();
}
};
}
bool containsFunctionCallExpression(Model const& model) {
bool result = false;
detail::FunctionCallExpressionFinderTraverser().traverse(model, &result);
return result;
}
std::unordered_set<std::string> getOccurringFunctionCalls(storm::expressions::Expression const& expression) {
detail::FunctionCallExpressionFinderExpressionVisitor v;
std::unordered_set<std::string> result;
expression.accept(v, &result);
return result;
}
}
}

20
src/storm/storage/jani/traverser/FunctionCallExpressionFinder.h

@ -0,0 +1,20 @@
#pragma once
#include <unordered_set>
#include <string>
namespace storm {
namespace expressions {
class Expression;
}
namespace jani {
class Model;
bool containsFunctionCallExpression(Model const& model);
std::unordered_set<std::string> getOccurringFunctionCalls(storm::expressions::Expression const& expr);
}
}

20
src/storm/storage/jani/traverser/JaniTraverser.cpp

@ -11,6 +11,9 @@ namespace storm {
for (auto& c : model.getConstants()) { for (auto& c : model.getConstants()) {
traverse(c, data); traverse(c, data);
} }
for (auto& f : model.getGlobalFunctionDefinitions()) {
traverse(f.second, data);
}
traverse(model.getGlobalVariables(), data); traverse(model.getGlobalVariables(), data);
for (auto& aut : model.getAutomata()) { for (auto& aut : model.getAutomata()) {
traverse(aut, data); traverse(aut, data);
@ -26,6 +29,9 @@ namespace storm {
void JaniTraverser::traverse(Automaton& automaton, boost::any const& data) { void JaniTraverser::traverse(Automaton& automaton, boost::any const& data) {
traverse(automaton.getVariables(), data); traverse(automaton.getVariables(), data);
for (auto& f : automaton.getFunctionDefinitions()) {
traverse(f.second, data);
}
for (auto& loc : automaton.getLocations()) { for (auto& loc : automaton.getLocations()) {
traverse(loc, data); traverse(loc, data);
} }
@ -41,6 +47,10 @@ namespace storm {
} }
} }
void JaniTraverser::traverse(FunctionDefinition& functionDefinition, boost::any const& data) {
traverse(functionDefinition.getFunctionBody(), data);
}
void JaniTraverser::traverse(VariableSet& variableSet, boost::any const& data) { void JaniTraverser::traverse(VariableSet& variableSet, boost::any const& data) {
for (auto& v : variableSet.getBooleanVariables()) { for (auto& v : variableSet.getBooleanVariables()) {
traverse(v, data); traverse(v, data);
@ -162,6 +172,9 @@ namespace storm {
for (auto const& c : model.getConstants()) { for (auto const& c : model.getConstants()) {
traverse(c, data); traverse(c, data);
} }
for (auto const& f : model.getGlobalFunctionDefinitions()) {
traverse(f.second, data);
}
traverse(model.getGlobalVariables(), data); traverse(model.getGlobalVariables(), data);
for (auto const& aut : model.getAutomata()) { for (auto const& aut : model.getAutomata()) {
traverse(aut, data); traverse(aut, data);
@ -177,6 +190,9 @@ namespace storm {
void ConstJaniTraverser::traverse(Automaton const& automaton, boost::any const& data) { void ConstJaniTraverser::traverse(Automaton const& automaton, boost::any const& data) {
traverse(automaton.getVariables(), data); traverse(automaton.getVariables(), data);
for (auto const& f : automaton.getFunctionDefinitions()) {
traverse(f.second, data);
}
for (auto const& loc : automaton.getLocations()) { for (auto const& loc : automaton.getLocations()) {
traverse(loc, data); traverse(loc, data);
} }
@ -192,6 +208,10 @@ namespace storm {
} }
} }
void ConstJaniTraverser::traverse(FunctionDefinition const& functionDefinition, boost::any const& data) {
traverse(functionDefinition.getFunctionBody(), data);
}
void ConstJaniTraverser::traverse(VariableSet const& variableSet, boost::any const& data) { void ConstJaniTraverser::traverse(VariableSet const& variableSet, boost::any const& data) {
for (auto const& v : variableSet.getBooleanVariables()) { for (auto const& v : variableSet.getBooleanVariables()) {
traverse(v, data); traverse(v, data);

2
src/storm/storage/jani/traverser/JaniTraverser.h

@ -16,6 +16,7 @@ namespace storm {
virtual void traverse(Action const& action, boost::any const& data); virtual void traverse(Action const& action, boost::any const& data);
virtual void traverse(Automaton& automaton, boost::any const& data); virtual void traverse(Automaton& automaton, boost::any const& data);
virtual void traverse(Constant& constant, boost::any const& data); virtual void traverse(Constant& constant, boost::any const& data);
virtual void traverse(FunctionDefinition& functionDefinition, boost::any const& data);
virtual void traverse(VariableSet& variableSet, boost::any const& data); virtual void traverse(VariableSet& variableSet, boost::any const& data);
virtual void traverse(Location& location, boost::any const& data); virtual void traverse(Location& location, boost::any const& data);
virtual void traverse(BooleanVariable& variable, boost::any const& data); virtual void traverse(BooleanVariable& variable, boost::any const& data);
@ -43,6 +44,7 @@ namespace storm {
virtual void traverse(Action const& action, boost::any const& data); virtual void traverse(Action const& action, boost::any const& data);
virtual void traverse(Automaton const& automaton, boost::any const& data); virtual void traverse(Automaton const& automaton, boost::any const& data);
virtual void traverse(Constant const& constant, boost::any const& data); virtual void traverse(Constant const& constant, boost::any const& data);
virtual void traverse(FunctionDefinition const& functionDefinition, boost::any const& data);
virtual void traverse(VariableSet const& variableSet, boost::any const& data); virtual void traverse(VariableSet const& variableSet, boost::any const& data);
virtual void traverse(Location const& location, boost::any const& data); virtual void traverse(Location const& location, boost::any const& data);
virtual void traverse(BooleanVariable const& variable, boost::any const& data); virtual void traverse(BooleanVariable const& variable, boost::any const& data);

Loading…
Cancel
Save