#include "src/storage/expressions/LinearCoefficientVisitor.h"

#include "src/storage/expressions/Expressions.h"
#include "src/exceptions/ExceptionMacros.h"
#include "src/exceptions/InvalidArgumentException.h"

namespace storm {
    namespace expressions {
        /*!
         * Decomposes the given expression into its linear form sum_i (a_i * x_i) + c.
         *
         * @param expression The expression to decompose.
         * @return A pair whose first component maps each variable name to its coefficient a_i and whose second
         * component is the constant offset c.
         * @throws InvalidArgumentException (via the visit methods) if the expression is not linear.
         */
        std::pair<SimpleValuation, double> LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
            // A previous traversal that was aborted by an exception can leave partial results on the stack; discard
            // them so they cannot accumulate across calls.
            while (!resultStack.empty()) {
                resultStack.pop();
            }
            
            expression.getBaseExpression().accept(this);
            
            // Visiting the root leaves the result for the whole expression on top of the stack. Take it off so the
            // stack does not keep growing when the visitor is reused.
            std::pair<SimpleValuation, double> result = std::move(resultStack.top());
            resultStack.pop();
            return result;
        }
        
        // If-then-else expressions are never treated as linear terms, so they are rejected unconditionally.
        void LinearCoefficientVisitor::visit(IfThenElseExpression const* /* expression */) {
            LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
        }
        
        // Boolean connectives cannot occur in a linear numerical expression, so they are rejected unconditionally.
        void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* /* expression */) {
            LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
        }
        
        /*!
         * Combines the linear forms of both operands of +, -, * or /. The left operand's result is removed from the
         * stack; the right operand's result stays on top of the stack and is updated in place to become the result of
         * this expression.
         *
         * @throws InvalidArgumentException if the operator is not one of +, -, *, /, if both factors of a product
         * contain variables, or if the divisor of a division contains variables.
         */
        void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
            typedef BinaryNumericalFunctionExpression::OperatorType OperatorType;
            OperatorType const operatorType = expression->getOperatorType();
            
            // Reject all other operators before visiting any operand, so no partial results are pushed.
            bool const isSupportedOperator = operatorType == OperatorType::Plus || operatorType == OperatorType::Minus || operatorType == OperatorType::Times || operatorType == OperatorType::Divide;
            LOG_THROW(isSupportedOperator, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
            
            expression->getFirstOperand()->accept(this);
            std::pair<SimpleValuation, double> leftResult = resultStack.top();
            resultStack.pop();
            expression->getSecondOperand()->accept(this);
            std::pair<SimpleValuation, double>& rightResult = resultStack.top();
            
            if (operatorType == OperatorType::Plus || operatorType == OperatorType::Minus) {
                if (operatorType == OperatorType::Minus) {
                    // (a.x + c) - (b.x + d) = (a.x + c) + (-b.x - d): negate the right coefficients, then add as for plus.
                    for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
                        rightResult.first.setDoubleValue(identifier, -rightResult.first.getDoubleValue(identifier));
                    }
                    rightResult.second = -rightResult.second;
                }
                
                // (a.x + c) + (b.x + d) = (a + b).x + (c + d)
                for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
                    double const leftCoefficient = leftResult.first.getDoubleValue(identifier);
                    if (rightResult.first.containsDoubleIdentifier(identifier)) {
                        rightResult.first.setDoubleValue(identifier, leftCoefficient + rightResult.first.getDoubleValue(identifier));
                    } else {
                        // The identifier is not registered in the right valuation yet, so it must be added, not set.
                        rightResult.first.addDoubleIdentifier(identifier, leftCoefficient);
                    }
                }
                rightResult.second += leftResult.second;
            } else if (operatorType == OperatorType::Times) {
                // A product is only linear if at least one of the factors is a constant.
                bool const leftIsConstant = leftResult.first.getNumberOfIdentifiers() == 0;
                LOG_THROW(leftIsConstant || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
                if (leftIsConstant) {
                    // c * (b.x + d) = (c * b).x + c * d
                    for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
                        rightResult.first.setDoubleValue(identifier, leftResult.second * rightResult.first.getDoubleValue(identifier));
                    }
                } else {
                    // (a.x + c) * d = (a * d).x + c * d; the right valuation has no identifiers, so all are added.
                    for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
                        rightResult.first.addDoubleIdentifier(identifier, rightResult.second * leftResult.first.getDoubleValue(identifier));
                    }
                }
                rightResult.second *= leftResult.second;
            } else {
                // Division is only linear if the divisor is a constant: (a.x + c) / d = (a / d).x + c / d.
                LOG_THROW(rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
                for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
                    rightResult.first.addDoubleIdentifier(identifier, leftResult.first.getDoubleValue(identifier) / rightResult.second);
                }
                rightResult.second = leftResult.second / rightResult.second;
            }
        }
        
        // Relations (comparisons) yield boolean values and are therefore never linear numerical terms.
        void LinearCoefficientVisitor::visit(BinaryRelationExpression const* /* expression */) {
            LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
        }
        
        /*!
         * A numerical variable contributes itself with coefficient one and no constant offset.
         *
         * @throws InvalidArgumentException for boolean variables or variables with an undefined type.
         */
        void LinearCoefficientVisitor::visit(VariableExpression const* expression) {
            SimpleValuation coefficients;
            switch (expression->getReturnType()) {
                case ExpressionReturnType::Int:
                case ExpressionReturnType::Double:
                    // Integer and double variables are both recorded as double identifiers.
                    coefficients.addDoubleIdentifier(expression->getVariableName(), 1);
                    break;
                case ExpressionReturnType::Bool:
                    LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
                    break;
                case ExpressionReturnType::Undefined:
                    LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type.");
                    break;
            }
            
            resultStack.push(std::pair<SimpleValuation, double>(coefficients, 0.0));
        }
        
        // Boolean negation cannot occur in a linear numerical expression, so it is rejected unconditionally.
        void LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const* /* expression */) {
            LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
        }
        
        /*!
         * Handles unary numerical operators. Only negation preserves linearity: -(a.x + c) = (-a).x + (-c).
         *
         * @throws InvalidArgumentException for any unary operator other than minus.
         */
        void LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const* expression) {
            LOG_THROW(expression->getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
            
            // The operand must be visited first, so its linear form is the entry on top of the stack.
            expression->getOperand()->accept(this);
            std::pair<SimpleValuation, double>& operandResult = resultStack.top();
            
            // Negate every coefficient as well as the constant offset.
            for (auto const& identifier : operandResult.first.getDoubleIdentifiers()) {
                operandResult.first.setDoubleValue(identifier, -operandResult.first.getDoubleValue(identifier));
            }
            operandResult.second = -operandResult.second;
        }
        
        // Boolean literals have no numerical value and are therefore rejected unconditionally.
        void LinearCoefficientVisitor::visit(BooleanLiteralExpression const* /* expression */) {
            LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
        }
        
        // An integer literal is a pure constant: it has no variable coefficients, and its value becomes the offset.
        void LinearCoefficientVisitor::visit(IntegerLiteralExpression const* expression) {
            double const constantOffset = static_cast<double>(expression->getValue());
            resultStack.push(std::pair<SimpleValuation, double>(SimpleValuation(), constantOffset));
        }
        
        // A double literal is a pure constant: it has no variable coefficients, and its value becomes the offset.
        void LinearCoefficientVisitor::visit(DoubleLiteralExpression const* expression) {
            SimpleValuation const noCoefficients;
            resultStack.push(std::make_pair(noCoefficients, expression->getValue()));
        }
    }
}