#include "gtest/gtest.h"
#include "storm-config.h"

#include "storm-dft/api/storm-dft.h"
#include "storm-parsers/api/storm-parsers.h"

namespace {

    // Configurations for DFT approximation
    struct DftAnalysisConfig {
        storm::builder::ApproximationHeuristic heuristic;
        bool useSR;
    };

    class ApproxDepthConfig {
    public:
        typedef double ValueType;
        static DftAnalysisConfig createConfig() {
            return DftAnalysisConfig {storm::builder::ApproximationHeuristic::DEPTH, false};
        }
    };
    class ApproxProbabilityConfig {
    public:
        typedef double ValueType;
        static DftAnalysisConfig createConfig() {
            return DftAnalysisConfig {storm::builder::ApproximationHeuristic::PROBABILITY, false};
        }
    };
    class ApproxBoundDifferenceConfig {
    public:
        typedef double ValueType;
        static DftAnalysisConfig createConfig() {
            return DftAnalysisConfig {storm::builder::ApproximationHeuristic::BOUNDDIFFERENCE, false};
        }
    };

    // General base class for testing of DFT approximation
    template<typename TestType>
    class DftApproximationTest : public ::testing::Test {
    public:
        typedef typename TestType::ValueType ValueType;

        DftApproximationTest() : config(TestType::createConfig()) {
        }

        DftApproximationTest const& getConfig() const {
            return config;
        }

        std::pair<double, double> analyzeMTTF(std::string const& file, double errorBound) {
            std::shared_ptr<storm::storage::DFT<double>> dft = storm::api::loadDFTGalileoFile<double>(file);
            EXPECT_TRUE(storm::api::isWellFormed(*dft));
            std::string property = "T=? [F \"failed\"]";
            std::vector<std::shared_ptr<storm::logic::Formula const>> properties = storm::api::extractFormulasFromProperties(storm::api::parseProperties(property));
            typename storm::modelchecker::DFTModelChecker<double>::dft_results results = storm::api::analyzeDFTApprox<double>(*dft, properties, config.useSR, false, true, errorBound, config.heuristic, false);
            return boost::get<storm::modelchecker::DFTModelChecker<double>::approximation_result>(results[0]);
        }

        std::pair<double, double> analyzeTimebound(std::string const& file, double timeBound, double errorBound) {
            std::shared_ptr<storm::storage::DFT<double>> dft = storm::api::loadDFTGalileoFile<double>(file);
            EXPECT_TRUE(storm::api::isWellFormed(*dft));
            std::stringstream propertyStream;
            propertyStream << "P=? [F<=" << timeBound << " \"failed\"]";
            std::vector<std::shared_ptr<storm::logic::Formula const>> properties = storm::api::extractFormulasFromProperties(storm::api::parseProperties(propertyStream.str()));
            typename storm::modelchecker::DFTModelChecker<double>::dft_results results = storm::api::analyzeDFTApprox<double>(*dft, properties, config.useSR, false, true, errorBound, config.heuristic, false);
            return boost::get<storm::modelchecker::DFTModelChecker<double>::approximation_result>(results[0]);
        }

    private:
        DftAnalysisConfig config;
    };

    typedef ::testing::Types<
            ApproxDepthConfig,
            ApproxProbabilityConfig,
            ApproxBoundDifferenceConfig
        > TestingTypes;

    TYPED_TEST_CASE(DftApproximationTest, TestingTypes);

    TYPED_TEST(DftApproximationTest, HecsMTTF) {
        double errorBound = 2;
        std::pair<double, double> approxResult = this->analyzeMTTF(STORM_TEST_RESOURCES_DIR "/dft/hecs_3_2_2_np.dft", errorBound);
        EXPECT_LE(approxResult.first, 417.9436693);
        EXPECT_GE(approxResult.second, 417.9436693);
        EXPECT_LE(2*(approxResult.second - approxResult.first) / (approxResult.first + approxResult.second), errorBound);
        // Ensure results are not equal -> not exact values were computed
        EXPECT_GE(approxResult.second - approxResult.first, errorBound * approxResult.first / 10);
    }

    TYPED_TEST(DftApproximationTest, HecsTimebound) {
        //double errorBound = 0.01;
        double errorBound = 0.5;
        double timeBound = 100;
        std::pair<double, double> approxResult = this->analyzeTimebound(STORM_TEST_RESOURCES_DIR "/dft/hecs_3_2_2_np.dft", timeBound, errorBound);
        EXPECT_LE(approxResult.first, 0.0410018417);
        EXPECT_GE(approxResult.second, 0.0410018417);
        EXPECT_LE(approxResult.second - approxResult.first, errorBound);
        // Ensure results are not equal -> not exact values were computed
        EXPECT_GE(approxResult.second - approxResult.first, errorBound / 10);
    }

}