diff --git a/src/storm-parsers/parser/PrismParser.cpp b/src/storm-parsers/parser/PrismParser.cpp index 5d76736f7..6d3e7b239 100644 --- a/src/storm-parsers/parser/PrismParser.cpp +++ b/src/storm-parsers/parser/PrismParser.cpp @@ -428,13 +428,30 @@ namespace storm { } storm::prism::Formula PrismParser::createFormula(std::string const& formulaName, storm::expressions::Expression expression) { - // Only register formula in second run. This prevents the parser from accepting formulas that depend on future - // formulas. + // Only register formula in second run. + // This is necessary because the resulting type of the formula is only known in the second run. + // This prevents the parser from accepting formulas that depend on future formulas. + storm::expressions::Variable variable; if (this->secondRun) { - STORM_LOG_THROW(this->identifiers_.find(formulaName) == nullptr, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Duplicate identifier '" << formulaName << "'."); - this->identifiers_.add(formulaName, expression); + try { + if (expression.hasIntegerType()) { + variable = manager->declareIntegerVariable(formulaName); + } else if (expression.hasBooleanType()) { + variable = manager->declareBooleanVariable(formulaName); + } else { + STORM_LOG_ASSERT(expression.hasNumericalType(), "Unexpected type for formula expression of formula " << formulaName); + variable = manager->declareRationalVariable(formulaName); + } + this->identifiers_.add(formulaName, variable.getExpression()); + } catch (storm::exceptions::InvalidArgumentException const& e) { + if (manager->hasVariable(formulaName)) { + STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Duplicate identifier '" << formulaName << "'."); + } else { + STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": illegal identifier '" << formulaName << "'."); + } + } } - return storm::prism::Formula(formulaName, expression, this->getFilename()); + return storm::prism::Formula(variable, expression, this->getFilename()); } storm::prism::Label PrismParser::createLabel(std::string const& labelName, storm::expressions::Expression expression) const { diff --git a/src/storm/storage/prism/Formula.cpp b/src/storm/storage/prism/Formula.cpp index 3134bc033..1f1e7f09a 100644 --- a/src/storm/storage/prism/Formula.cpp +++ b/src/storm/storage/prism/Formula.cpp @@ -2,12 +2,16 @@ namespace storm { namespace prism { - Formula::Formula(std::string const& name, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), name(name), expression(expression) { + Formula::Formula(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), expression(expression) { // Intentionally left empty. } std::string const& Formula::getName() const { - return this->name; + return this->variable.getName(); + } + + storm::expressions::Variable const& Formula::getExpressionVariable() const { + return this->variable; } storm::expressions::Expression const& Formula::getExpression() const { @@ -15,11 +19,12 @@ namespace storm { } storm::expressions::Type const& Formula::getType() const { - return this->getExpression().getType(); + assert(this->getExpressionVariable().getType() == this->getExpression().getType()); + return this->getExpressionVariable().getType(); } Formula Formula::substitute(std::map const& substitution) const { - return Formula(this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); + return Formula(this->getExpressionVariable(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } std::ostream& operator<<(std::ostream& stream, Formula const& formula) { diff --git a/src/storm/storage/prism/Formula.h b/src/storm/storage/prism/Formula.h index 17d2f0cbd..656361111 100644 --- a/src/storm/storage/prism/Formula.h +++ b/src/storm/storage/prism/Formula.h @@ -15,12 +15,12 @@ namespace storm { /*! * Creates a formula with the given name and expression. * - * @param name The name of the formula. + * @param placeholder The placeholder variable that is used in expressions to represent this formula. * @param expression The expression associated with this formula. * @param filename The filename in which the transition reward is defined. * @param lineNumber The line number in which the transition reward is defined. */ - Formula(std::string const& name, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0); + Formula(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0); // Create default implementations of constructors/assignment. Formula() = default; @@ -36,6 +36,13 @@ namespace storm { */ std::string const& getName() const; + /*! + * Retrieves the placeholder variable that is used in expressions to represent this formula. + * + * @return The placeholder variable that is used in expressions to represent this formula. + */ + storm::expressions::Variable const& getExpressionVariable() const; + /*! * Retrieves the expression that is associated with this formula. * @@ -62,7 +69,7 @@ namespace storm { private: // The name of the formula. - std::string name; + storm::expressions::Variable variable; // A predicate that needs to be satisfied by states for the label to be attached. storm::expressions::Expression expression; diff --git a/src/storm/storage/prism/Program.cpp b/src/storm/storage/prism/Program.cpp index 45a4f5a1c..c0913257d 100644 --- a/src/storm/storage/prism/Program.cpp +++ b/src/storm/storage/prism/Program.cpp @@ -787,63 +787,79 @@ namespace storm { } Program Program::substituteConstants() const { + return substituteConstantsFormulas(true, false); + } + + Program Program::substituteFormulas() const { + return substituteConstantsFormulas(false, true); + } + + Program Program::substituteConstantsFormulas(bool substituteConstants, bool substituteFormulas) const { + // We start by creating the appropriate substitution. - std::map constantSubstitution; - std::vector newConstants(this->getConstants()); - for (uint_fast64_t constantIndex = 0; constantIndex < newConstants.size(); ++constantIndex) { - auto const& constant = newConstants[constantIndex]; + std::map substitution; + + // Start with substituting constants. In a sane model, constant definitions do not contain formulas. + std::vector newConstants; + newConstants.reserve(this->getNumberOfConstants()); + for (auto const& oldConstant : this->getConstants()) { + // apply the substitutions gathered so far to the constant definition *before* adding it to the substitution. + newConstants.push_back(oldConstant.substitute(substitution)); - // Put the corresponding expression in the substitution. - if (constant.isDefined()) { - constantSubstitution.emplace(constant.getExpressionVariable(), constant.getExpression().simplify()); - - // If there is at least one more constant to come, we substitute the constants we have so far. - if (constantIndex + 1 < newConstants.size()) { - newConstants[constantIndex + 1] = newConstants[constantIndex + 1].substitute(constantSubstitution); - } + // Put the corresponding expression in the substitution (if requested). + auto const& constant = newConstants.back(); + if (substituteConstants && constant.isDefined()) { + substitution.emplace(constant.getExpressionVariable(), constant.getExpression().simplify()); } } - // Now we can substitute the constants in all expressions appearing in the program. + // Secondly, handle the formulas. These might contain constants + std::vector newFormulas; + newFormulas.reserve(this->getNumberOfFormulas()); + for (auto const& oldFormula : this->getFormulas()) { + // apply the currently gathered substitutions on the formula definition *before* adding it to the substitution. + newFormulas.emplace_back(oldFormula.substitute(substitution)); + // Put the corresponding expression in the substitution (if requested). + auto const& formula = newFormulas.back(); + if (substituteFormulas) { + substitution.emplace(formula.getExpressionVariable(), formula.getExpression().simplify()); + } + } + + // Now we can substitute the constants/formulas in all expressions appearing in the program. std::vector newBooleanVariables; newBooleanVariables.reserve(this->getNumberOfGlobalBooleanVariables()); for (auto const& booleanVariable : this->getGlobalBooleanVariables()) { - newBooleanVariables.emplace_back(booleanVariable.substitute(constantSubstitution)); + newBooleanVariables.emplace_back(booleanVariable.substitute(substitution)); } std::vector newIntegerVariables; newBooleanVariables.reserve(this->getNumberOfGlobalIntegerVariables()); for (auto const& integerVariable : this->getGlobalIntegerVariables()) { - newIntegerVariables.emplace_back(integerVariable.substitute(constantSubstitution)); - } - - std::vector newFormulas; - newFormulas.reserve(this->getNumberOfFormulas()); - for (auto const& formula : this->getFormulas()) { - newFormulas.emplace_back(formula.substitute(constantSubstitution)); + newIntegerVariables.emplace_back(integerVariable.substitute(substitution)); } std::vector newModules; newModules.reserve(this->getNumberOfModules()); for (auto const& module : this->getModules()) { - newModules.emplace_back(module.substitute(constantSubstitution)); + newModules.emplace_back(module.substitute(substitution)); } std::vector newRewardModels; newRewardModels.reserve(this->getNumberOfRewardModels()); for (auto const& rewardModel : this->getRewardModels()) { - newRewardModels.emplace_back(rewardModel.substitute(constantSubstitution)); + newRewardModels.emplace_back(rewardModel.substitute(substitution)); } boost::optional newInitialConstruct; if (this->hasInitialConstruct()) { - newInitialConstruct = this->getInitialConstruct().substitute(constantSubstitution); + newInitialConstruct = this->getInitialConstruct().substitute(substitution); } std::vector