#include "src/storage/jani/Automaton.h"

#include <algorithm>

#include "src/utility/macros.h"
#include "src/exceptions/WrongFormatException.h"
#include "src/exceptions/InvalidArgumentException.h"
#include "src/exceptions/InvalidTypeException.h"

namespace storm {
    namespace jani {
        
        namespace detail {
            Edges::Edges(iterator it, iterator ite) : it(it), ite(ite) {
                // Intentionally left empty.
            }
            
            Edges::iterator Edges::begin() const {
                return it;
            }
            
            Edges::iterator Edges::end() const {
                return ite;
            }
            
            bool Edges::empty() const {
                return it == ite;
            }
            
            std::size_t Edges::size() const {
                return std::distance(it, ite);
            }
            
            ConstEdges::ConstEdges(const_iterator it, const_iterator ite) : it(it), ite(ite) {
                // Intentionally left empty.
            }
            
            ConstEdges::const_iterator ConstEdges::begin() const {
                return it;
            }
            
            ConstEdges::const_iterator ConstEdges::end() const {
                return ite;
            }

            bool ConstEdges::empty() const {
                return it == ite;
            }

            std::size_t ConstEdges::size() const {
                return std::distance(it, ite);
            }
        }
        
        Automaton::Automaton(std::string const& name) : name(name) {
            // Add a sentinel element to the mapping from locations to starting indices.
            locationToStartingIndex.push_back(0);
        }
        
        std::string const& Automaton::getName() const {
            return name;
        }

        // Adds a variable of any supported kind by dispatching to the overload for its concrete type.
        // Throws InvalidTypeException if the variable is none of boolean, bounded integer, unbounded integer or real.
        Variable const& Automaton::addVariable(Variable const &variable) {
            if (variable.isBooleanVariable()) {
                return addVariable(variable.asBooleanVariable());
            }
            if (variable.isBoundedIntegerVariable()) {
                return addVariable(variable.asBoundedIntegerVariable());
            }
            if (variable.isUnboundedIntegerVariable()) {
                return addVariable(variable.asUnboundedIntegerVariable());
            }
            if (variable.isRealVariable()) {
                return addVariable(variable.asRealVariable());
            }
            STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Variable has invalid type.");
        }
        
        // The typed overloads below forward to the automaton's variable set and return the stored variable.
        BooleanVariable const& Automaton::addVariable(BooleanVariable const& variable) {
            return this->variables.addVariable(variable);
        }
        
        BoundedIntegerVariable const& Automaton::addVariable(BoundedIntegerVariable const& variable) {
            return this->variables.addVariable(variable);
        }

        UnboundedIntegerVariable const& Automaton::addVariable(UnboundedIntegerVariable const& variable) {
            return this->variables.addVariable(variable);
        }

        RealVariable const& Automaton::addVariable(RealVariable const& variable) {
            return this->variables.addVariable(variable);
        }

        VariableSet& Automaton::getVariables() {
            return variables;
        }

        VariableSet const& Automaton::getVariables() const {
            return variables;
        }
        
        bool Automaton::hasTransientVariable() const {
            return variables.hasTransientVariable();
        }
        
        bool Automaton::hasLocation(std::string const& name) const {
            return locationToIndex.find(name) != locationToIndex.end();
        }
        
        std::vector<Location> const& Automaton::getLocations() const {
            return locations;
        }

        std::vector<Location>& Automaton::getLocations() {
            return locations;
        }

        Location const& Automaton::getLocation(uint64_t index) const {
            return locations[index];
        }

        Location& Automaton::getLocation(uint64_t index) {
            return locations[index];
        }

        // Appends a location and returns its index.
        // Throws WrongFormatException if a location with the same name already exists.
        uint64_t Automaton::addLocation(Location const& location) {
            auto const& locationName = location.getName();
            STORM_LOG_THROW(!this->hasLocation(locationName), storm::exceptions::WrongFormatException, "Cannot add location with name '" << locationName << "', because a location with this name already exists.");
            
            // The new location takes the next free index.
            uint64_t const newIndex = locations.size();
            locationToIndex.emplace(locationName, newIndex);
            locations.push_back(location);
            
            // The new location has no edges yet, so its range ends where all existing edges end.
            locationToStartingIndex.push_back(edges.size());
            return newIndex;
        }

        uint64_t Automaton::getLocationIndex(std::string const& name) const {
            assert(hasLocation(name));
            return locationToIndex.at(name);
        }

        void Automaton::addInitialLocation(std::string const& name) {
            auto it = locationToIndex.find(name);
            STORM_LOG_THROW(it != locationToIndex.end(), storm::exceptions::InvalidArgumentException, "Cannot make unknown location '" << name << "' the initial location.");
            return addInitialLocation(it->second);
        }
        
        void Automaton::addInitialLocation(uint64_t index) {
            STORM_LOG_THROW(index < locations.size(), storm::exceptions::InvalidArgumentException, "Cannot make location with index " << index << " initial: out of bounds.");
            initialLocationIndices.insert(index);
        }
        
        std::set<uint64_t> const& Automaton::getInitialLocationIndices() const {
            return initialLocationIndices;
        }

        // Returns a mutable view of all edges leaving the location with the given name.
        // Throws InvalidArgumentException if no location with that name exists.
        Automaton::Edges Automaton::getEdgesFromLocation(std::string const& name) {
            auto it = locationToIndex.find(name);
            // The location name in the message is enclosed in a matching pair of quotes.
            STORM_LOG_THROW(it != locationToIndex.end(), storm::exceptions::InvalidArgumentException, "Cannot retrieve edges from unknown location '" << name << "'.");
            return getEdgesFromLocation(it->second);
        }
        
        // Returns a mutable view of all edges leaving location `index` (not bounds-checked).
        // Those edges occupy [locationToStartingIndex[index], locationToStartingIndex[index + 1]) in `edges`.
        Automaton::Edges Automaton::getEdgesFromLocation(uint64_t index) {
            auto const rangeBegin = edges.begin() + locationToStartingIndex[index];
            auto const rangeEnd = edges.begin() + locationToStartingIndex[index + 1];
            return Edges(rangeBegin, rangeEnd);
        }
        
        // Returns a read-only view of all edges leaving the location with the given name.
        // Throws InvalidArgumentException if no location with that name exists.
        Automaton::ConstEdges Automaton::getEdgesFromLocation(std::string const& name) const {
            auto it = locationToIndex.find(name);
            // The location name in the message is enclosed in a matching pair of quotes.
            STORM_LOG_THROW(it != locationToIndex.end(), storm::exceptions::InvalidArgumentException, "Cannot retrieve edges from unknown location '" << name << "'.");
            return getEdgesFromLocation(it->second);
        }
        
        // Returns a read-only view of all edges leaving location `index` (not bounds-checked).
        // Those edges occupy [locationToStartingIndex[index], locationToStartingIndex[index + 1]) in `edges`.
        Automaton::ConstEdges Automaton::getEdgesFromLocation(uint64_t index) const {
            auto const rangeBegin = edges.begin() + locationToStartingIndex[index];
            auto const rangeEnd = edges.begin() + locationToStartingIndex[index + 1];
            return ConstEdges(rangeBegin, rangeEnd);
        }
        
        // Returns a mutable view of the edges leaving location `locationIndex` whose action index equals `actionIndex`.
        // The returned range is empty if there is no such edge. `locationIndex` is not bounds-checked.
        //
        // This relies on addEdge keeping the edges of each location sorted by action index. Under that invariant the
        // matching edges form one contiguous sub-range, which is located by binary search.
        // NOTE(review): getEdges() hands out the vector mutably, so callers changing action indices through it could
        // break this invariant.
        Automaton::Edges Automaton::getEdgesFromLocation(uint64_t locationIndex, uint64_t actionIndex) {
            auto first = edges.begin();
            std::advance(first, locationToStartingIndex[locationIndex]);
            auto last = edges.begin();
            std::advance(last, locationToStartingIndex[locationIndex + 1]);
            
            // First edge whose action index is not smaller than the requested one.
            auto rangeBegin = std::lower_bound(first, last, actionIndex, [] (Edge const& edge, uint64_t index) { return edge.getActionIndex() < index; });
            // First edge (at or after rangeBegin) whose action index is larger than the requested one. Searching the
            // whole remaining range [rangeBegin, last) is what the previous hand-written loop failed to do: it
            // restarted every probe at the range start and therefore could stop too early.
            auto rangeEnd = std::upper_bound(rangeBegin, last, actionIndex, [] (uint64_t index, Edge const& edge) { return index < edge.getActionIndex(); });
            
            return Edges(rangeBegin, rangeEnd);
        }
        
        // Returns a read-only view of the edges leaving location `locationIndex` whose action index equals `actionIndex`.
        // The returned range is empty if there is no such edge. `locationIndex` is not bounds-checked.
        //
        // This relies on addEdge keeping the edges of each location sorted by action index. Under that invariant the
        // matching edges form one contiguous sub-range, which is located by binary search with the standard algorithms
        // rather than a hand-written loop.
        Automaton::ConstEdges Automaton::getEdgesFromLocation(uint64_t locationIndex, uint64_t actionIndex) const {
            auto first = edges.begin();
            std::advance(first, locationToStartingIndex[locationIndex]);
            auto last = edges.begin();
            std::advance(last, locationToStartingIndex[locationIndex + 1]);
            
            // First edge whose action index is not smaller than the requested one.
            auto rangeBegin = std::lower_bound(first, last, actionIndex, [] (Edge const& edge, uint64_t index) { return edge.getActionIndex() < index; });
            // First edge (at or after rangeBegin) whose action index is larger than the requested one.
            auto rangeEnd = std::upper_bound(rangeBegin, last, actionIndex, [] (uint64_t index, Edge const& edge) { return index < edge.getActionIndex(); });
            
            return ConstEdges(rangeBegin, rangeEnd);
        }
        
        // Adds an edge to the automaton, placing it in the contiguous range of its source location.
        // Throws InvalidArgumentException if the source location index does not refer to an existing location.
        //
        // Afterwards the source location's edges are sorted by action index, which the action-based
        // getEdgesFromLocation overloads rely on. A stable sort is used, so edges sharing an action index keep their
        // insertion order; std::sort would leave their relative order unspecified.
        void Automaton::addEdge(Edge const& edge) {
            uint64_t const sourceLocationIndex = edge.getSourceLocationIndex();
            STORM_LOG_THROW(sourceLocationIndex < locations.size(), storm::exceptions::InvalidArgumentException, "Cannot add edge with unknown source location index '" << sourceLocationIndex << "'.");
            
            // Insert the edge at the end of its source location's range.
            auto posIt = edges.begin();
            std::advance(posIt, locationToStartingIndex[sourceLocationIndex + 1]);
            edges.insert(posIt, edge);
            
            // Every subsequent location (including the trailing sentinel entry) now starts one position later.
            for (uint64_t locationIndex = sourceLocationIndex + 1; locationIndex < locationToStartingIndex.size(); ++locationIndex) {
                ++locationToStartingIndex[locationIndex];
            }
            
            // Sort all edges from the source location of the newly introduced edge by their action indices,
            // preserving the relative order of edges with equal action index.
            auto it = edges.begin();
            std::advance(it, locationToStartingIndex[sourceLocationIndex]);
            auto ite = edges.begin();
            std::advance(ite, locationToStartingIndex[sourceLocationIndex + 1]);
            std::stable_sort(it, ite, [] (Edge const& a, Edge const& b) { return a.getActionIndex() < b.getActionIndex(); } );
            
            // Update the set of action indices of this automaton.
            actionIndices.insert(edge.getActionIndex());
        }
        
        std::vector<Edge>& Automaton::getEdges() {
            return edges;
        }
        
        std::vector<Edge> const& Automaton::getEdges() const {
            return edges;
        }
        
        std::set<uint64_t> Automaton::getActionIndices() const {
            std::set<uint64_t> result;
            for (auto const& edge : edges) {
                result.insert(edge.getActionIndex());
            }
            return result;
        }
        
        uint64_t Automaton::getNumberOfLocations() const {
            return locations.size();
        }
        
        uint64_t Automaton::getNumberOfEdges() const {
            return edges.size();
        }

        bool Automaton::hasInitialStatesRestriction() const {
            return initialStatesRestriction.isInitialized();
        }
        
        storm::expressions::Expression const& Automaton::getInitialStatesRestriction() const {
            return initialStatesRestriction;
        }
        
        void Automaton::setInitialStatesRestriction(storm::expressions::Expression const& initialStatesRestriction) {
            this->initialStatesRestriction = initialStatesRestriction;
        }
        
        // Builds the conjunction of the initial-states restriction (if set) and one constraint per local variable with
        // an initial expression. Boolean variables use `iff`; all other variables use `==`.
        // Returns an uninitialized expression if there is no restriction and no variable has an initial expression.
        storm::expressions::Expression Automaton::getInitialStatesExpression() const {
            storm::expressions::Expression conjunction;
            if (this->hasInitialStatesRestriction()) {
                conjunction = this->getInitialStatesRestriction();
            }
            
            for (auto const& variable : this->getVariables()) {
                if (!variable.hasInitExpression()) {
                    continue;
                }
                
                storm::expressions::Expression initConstraint;
                if (variable.isBooleanVariable()) {
                    initConstraint = storm::expressions::iff(variable.getExpressionVariable(), variable.getInitExpression());
                } else {
                    initConstraint = variable.getExpressionVariable() == variable.getInitExpression();
                }
                
                // The first constraint seeds the conjunction; later ones are and-ed onto it.
                if (!conjunction.isInitialized()) {
                    conjunction = initConstraint;
                } else {
                    conjunction = conjunction && initConstraint;
                }
            }
            
            return conjunction;
        }
        
        bool Automaton::hasEdgeLabeledWithActionIndex(uint64_t actionIndex) const {
            return actionIndices.find(actionIndex) != actionIndices.end();
        }
        
        std::vector<storm::expressions::Expression> Automaton::getAllRangeExpressions() const {
            std::vector<storm::expressions::Expression> result;
            for (auto const& variable : this->getVariables().getBoundedIntegerVariables()) {
                result.push_back(variable.getRangeExpression());
            }
            return result;
        }
        
        // Applies the given variable-to-expression substitution to the automaton's contents. This covers:
        //  - the bounded integer variables,
        //  - all locations,
        //  - the initial-states restriction (only if one is set),
        //  - all edges.
        // NOTE(review): only bounded integer variables are substituted; other variable kinds are left untouched.
        void Automaton::substitute(std::map<storm::expressions::Variable, storm::expressions::Expression> const& substitution) {
            for (auto& variable : this->getVariables().getBoundedIntegerVariables()) {
                variable.substitute(substitution);
            }
            
            for (auto& location : this->getLocations()) {
                location.substitute(substitution);
            }
            
            // The restriction is optional (hasInitialStatesRestriction checks isInitialized()). Only substitute into an
            // actual expression; getInitialStatesExpression applies the same guard before reading it.
            if (this->hasInitialStatesRestriction()) {
                this->setInitialStatesRestriction(this->getInitialStatesRestriction().substitute(substitution));
            }
            
            for (auto& edge : this->getEdges()) {
                edge.substitute(substitution);
            }
        }
        
        // Finalizes every edge with respect to the model that contains this automaton.
        void Automaton::finalize(Model const& containingModel) {
            for (Edge& currentEdge : this->edges) {
                currentEdge.finalize(containingModel);
            }
        }

    }
}