#include "storm/simulator/PrismProgramSimulator.h"
#include "storm/exceptions/NotSupportedException.h"

using namespace storm::generator;

namespace storm {
    namespace simulator {

        template<typename ValueType>
        /// Constructs a simulator for the given PRISM program.
        /// Sets up a PrismNextStateGenerator from the program and options, initializes the
        /// per-reward-model zero vector (used to reset reward accumulation), then explores
        /// the unique initial state so the simulator is immediately usable.
        DiscreteTimePrismProgramSimulator<ValueType>::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, storm::generator::NextStateGeneratorOptions const& options)
        : program(program), currentState(), stateGenerator(std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(program, options)), zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero<ValueType>()), lastActionRewards(zeroRewards)
        {
            // Current state needs to be overwritten to actual initial state.
            // But first, let us create a state generator.

            // Start with empty id<->state caches sized for the generator's state encoding.
            clearStateCaches();
            resetToInitial();
        }

        template<typename ValueType>
        /// Re-seeds the random number generator used for sampling successor states,
        /// making subsequent simulation runs reproducible.
        void DiscreteTimePrismProgramSimulator<ValueType>::setSeed(uint64_t newSeed) {
            generator = storm::utility::RandomProbabilityGenerator<ValueType>(newSeed);
        }

        template<typename ValueType>
        /// Performs one simulation step by taking the given action in the current state.
        /// Samples a successor state from the action's probability distribution, records
        /// the action rewards, moves to the successor, and expands it.
        /// @param actionNumber index into the currently available choices.
        /// @return true on success, false if actionNumber is out of range.
        bool DiscreteTimePrismProgramSimulator<ValueType>::step(uint64_t actionNumber) {
            // Guard against invalid action indices instead of indexing out of bounds.
            if (actionNumber >= behavior.getChoices().size()) {
                return false;
            }
            auto const& choice = behavior.getChoices()[actionNumber];
            // Sample a successor state id according to the choice's distribution.
            uint32_t nextState = choice.sampleFromDistribution(generator.random());
            lastActionRewards = choice.getRewards();
            STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model.");
            currentState = idToState[nextState];
            // TODO we do not need to do this in every step!
            clearStateCaches();
            return explore();
        }

        template<typename ValueType>
        /// Expands the current state via the next-state generator and adds the state
        /// rewards of the freshly expanded state onto the accumulated action rewards.
        /// @return always true.
        bool DiscreteTimePrismProgramSimulator<ValueType>::explore() {
            // Feed the current state into the next-state generator before expanding.
            stateGenerator->load(currentState);
            // TODO: This low-level code currently expands all actions, while this is not necessary.
            // However, using the next state generator ensures compatibility with the model generator.
            behavior = stateGenerator->expand(stateToIdCallback);
            auto const& stateRewards = behavior.getStateRewards();
            STORM_LOG_ASSERT(stateRewards.size() == lastActionRewards.size(), "Reward vectors should have same length.");
            // Fold the state rewards into the reward vector reported for the last step.
            for (uint64_t rewardIndex = 0; rewardIndex < stateRewards.size(); ++rewardIndex) {
                lastActionRewards[rewardIndex] += stateRewards[rewardIndex];
            }
            return true;
        }

        template<typename ValueType>
        /// Checks whether the current state is a sink, i.e. whether it cannot be left:
        /// either it has no behavior at all, or every transition of every choice leads
        /// back to the current state itself.
        /// @return true iff the current state is a sink state.
        bool DiscreteTimePrismProgramSimulator<ValueType>::isSinkState() const {
            // A state without any behavior cannot be left.
            if(behavior.empty()) {
                return true;
            }
            // Collect distinct successor ids; abort early once there is more than one.
            std::set<uint32_t> successorIds;
            for (Choice<ValueType,uint32_t> const& choice : behavior.getChoices()) {
                for (auto it = choice.begin(); it != choice.end(); ++it) {
                    successorIds.insert(it->first);
                    if (successorIds.size() > 1) {
                        return false;
                    }
                }
            }
            // Defensive: choices without any transition yield no successors at all;
            // previously this would dereference begin() of an empty set (UB).
            if (successorIds.empty()) {
                return true;
            }
            // Exactly one successor: the state is a sink iff it loops back to itself.
            return idToState.at(*(successorIds.begin())) == currentState;
        }

        template<typename ValueType>
        /// Returns the choices (actions with their distributions and rewards) available
        /// in the current state, as produced by the last expansion.
        std::vector<generator::Choice<ValueType, uint32_t>> const& DiscreteTimePrismProgramSimulator<ValueType>::getChoices() const {
            return behavior.getChoices();
        }

        template<typename ValueType>
        /// Returns the rewards collected by the last step: the taken action's rewards
        /// plus the state rewards of the state reached (see explore()).
        std::vector<ValueType> const& DiscreteTimePrismProgramSimulator<ValueType>::getLastRewards() const {
            return lastActionRewards;
        }

        template<typename ValueType>
        /// Returns the current state in its compressed (bit-vector) representation.
        CompressedState const& DiscreteTimePrismProgramSimulator<ValueType>::getCurrentState() const {
            return currentState;
        }

        template<typename ValueType>
        /// Returns the current state unpacked into a variable valuation over the
        /// program's expression manager.
        expressions::SimpleValuation DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateAsValuation() const {
            return unpackStateIntoValuation(currentState, stateGenerator->getVariableInformation(), program.getManager());
        }

        template<typename ValueType>
        /// Returns a human-readable string representation of the current state.
        std::string DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateString() const {
            return stateGenerator->stateToString(currentState);
        }

        template<typename ValueType>
        /// Returns a JSON representation of the full current state
        /// (observability flag false — all variables included).
        storm::json<ValueType> DiscreteTimePrismProgramSimulator<ValueType>::getStateAsJson() const {
            return stateGenerator->currentStateToJson(false);
        }

        template<typename ValueType>
        /// Returns a JSON representation of the observable part of the current state
        /// (observability flag true).
        storm::json<ValueType> DiscreteTimePrismProgramSimulator<ValueType>::getObservationAsJson() const {
            return stateGenerator->currentStateToJson(true);
        }

        template<typename ValueType>
        /// Resets the simulator to the program's unique initial state, clearing any
        /// accumulated rewards, and expands that state.
        /// @throws NotSupportedException if the program has more than one initial state.
        /// @return always true (result of explore()).
        bool DiscreteTimePrismProgramSimulator<ValueType>::resetToInitial() {
            // Discard rewards accumulated before the reset.
            lastActionRewards = zeroRewards;
            // Ask the generator for all initial states; only a unique one is supported.
            auto initialIndices = stateGenerator->getInitialStates(stateToIdCallback);
            STORM_LOG_THROW(initialIndices.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state");
            currentState = idToState[initialIndices.front()];
            return explore();
        }

        template<typename ValueType>
        /// Resets the simulator to the given compressed state, clearing accumulated
        /// rewards, and expands that state.
        /// @param newState the compressed state to continue simulation from.
        /// @return always true (result of explore()).
        bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(generator::CompressedState const& newState) {
            lastActionRewards = zeroRewards;
            currentState = newState;
            return explore();
        }

        template<typename ValueType>
        /// Resets the simulator to the state described by the given valuation,
        /// clearing accumulated rewards, and expands that state.
        /// @param valuation assignment of all program variables describing the state.
        /// @return always true (result of explore()).
        bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(expressions::SimpleValuation const& valuation) {
            // Reset reward accumulation, consistent with resetToInitial() and the
            // CompressedState overload; without this, action rewards from before the
            // reset would leak into the rewards reported for the new state.
            lastActionRewards = zeroRewards;
            currentState = generator::packStateFromValuation(valuation, stateGenerator->getVariableInformation(), true);
            return explore();
        }

        template<typename ValueType>
        /// Returns the names of all reward models known to the state generator,
        /// in the generator's model order (matching the reward vector indices).
        std::vector<std::string> DiscreteTimePrismProgramSimulator<ValueType>::getRewardNames() const {
            uint64_t const numberOfRewardModels = stateGenerator->getNumberOfRewardModels();
            std::vector<std::string> rewardNames;
            rewardNames.reserve(numberOfRewardModels);
            for (uint64_t index = 0; index < numberOfRewardModels; ++index) {
                rewardNames.push_back(stateGenerator->getRewardModelInformation(index).getName());
            }
            return rewardNames;
        }

        template<typename ValueType>
        /// Looks up the index of the given state, registering it with a fresh index
        /// (and recording the reverse id-to-state mapping) if it is new.
        /// @param state the compressed state to look up or register.
        /// @return the state's index in the state-to-id cache.
        uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) {
            // Tentative index in case the state has not been seen before.
            uint32_t const freshIndex = static_cast<uint32_t>(stateToId.size());

            // Inserts the state under freshIndex if absent, otherwise returns the
            // already registered index (the bucket part of the pair is unused here).
            auto const indexBucketPair = stateToId.findOrAddAndGetBucket(state, freshIndex);

            uint32_t const actualIndex = indexBucketPair.first;
            if (actualIndex == freshIndex) {
                // The state was indeed new, so remember the reverse mapping too.
                idToState[actualIndex] = state;
            }
            return actualIndex;
        }

        template<typename ValueType>
        /// Drops all cached state-to-id and id-to-state mappings, reinitializing the
        /// hash map for the generator's current state (bit-vector) size.
        void DiscreteTimePrismProgramSimulator<ValueType>::clearStateCaches() {
            idToState.clear();
            stateToId = storm::storage::BitVectorHashMap<uint32_t>(stateGenerator->getStateSize());
        }

        // Explicit template instantiation for the value type used by clients.
        template class DiscreteTimePrismProgramSimulator<double>;
    }
}