#!/usr/bin/python3

import re, sys, os, shutil, fileinput, subprocess, argparse
from dataclasses import dataclass, field
import visvis as vv
import numpy as np
import argparse

from translate import translateTransitions, readLabels
from plotting import VisVisPlotter
from simulation import Simulator, Verdict


def convert(tuples):
    """Return a new dict built from ``tuples`` (key/value pairs or a mapping)."""
    mapping = dict(tuples)
    return mapping

def getBasename(filename):
    """Return the final path component of ``filename`` (empty if it ends in a separator)."""
    _, tail = os.path.split(filename)
    return tail

def traFileWithIteration(filename, iteration):
    """Return ``filename`` with its extension replaced by ``_<iteration>.tra``.

    The iteration number is zero-padded to at least three digits,
    e.g. ``("MDP_x.tra", 7) -> "MDP_x_007.tra"``.
    """
    stem, _extension = os.path.splitext(filename)
    return "{}_{:03}.tra".format(stem, iteration)

def copyFile(filename, newFilename):
    """Copy ``filename`` to ``newFilename`` via ``shutil.copy``; returns None."""
    shutil.copy(src=filename, dst=newFilename)


def execute(command, verbose=False):
    """Run ``command`` through the system shell (``os.system``).

    When ``verbose`` is true the command is echoed first. The exit status is
    not returned.
    """
    if verbose:
        print(f"Executing {command}")
    os.system(command)

@dataclass(eq=True, frozen=True)
class State:
    """Immutable, hashable valuation of a single state.

    Field order matters: parseStateValuations builds instances positionally
    from the numbers found on each line of the state-valuation file.
    """
    id: int
    # x coordinate and its *_vel companion
    x: float
    x_vel: float
    # y coordinate and its *_vel companion
    y: float
    y_vel: float
    # z coordinate and its *_vel companion
    z: float
    z_vel: float

def default_value():
    """Return a fresh placeholder choices dict with ``action`` and ``choiceValue`` unset."""
    return dict(action=None, choiceValue=None)


@dataclass(eq=True, frozen=True)
class StateValue:
    """Ranking information attached to one state.

    ``ranking`` holds a single float (parseRanking stores the min-max
    normalised value there). ``choices`` defaults to a fresh
    ``default_value()`` dict per instance.

    NOTE: although frozen, instances cannot actually be hashed because
    ``choices`` is a dict.
    """
    ranking: float
    # per-choice values; parseRanking fills it as {choice_index: value}
    choices: dict = field(default_factory=default_value)

@dataclass
class TestResult:
    """Mutable record of the eight values callTempest extracts from Tempest.

    callTempest constructs it positionally from the parsed ``Result`` lines,
    so the field order must match the order of its properties.
    """
    # filter(min/max/avg) over non-terminal states of Pmax=? [ true U<=bound "failed" ]
    prob_pes_min: float
    prob_pes_max: float
    prob_pes_avg: float
    # filter(min/max/avg) over non-terminal states of Pmin=? [ true U<=bound "failed" ]
    prob_opt_min: float
    prob_opt_max: float
    prob_opt_avg: float
    # filter(min, ...) of Rmin=? and Rmax=? [ C<=bound ]
    min_min: float
    min_max: float

    def csv(self, ws=" "):
        """Render every field with four decimals, each one followed by ``ws``.

        Note the trailing separator after the last value.
        """
        ordered = (
            self.prob_pes_min, self.prob_pes_max, self.prob_pes_avg,
            self.prob_opt_min, self.prob_opt_max, self.prob_opt_avg,
            self.min_min, self.min_max,
        )
        return "".join(f"{value:0.04f}{ws}" for value in ordered)

def parseStrategy(strategyFile, allStateActionPairs, time_index=0):
    """Read a strategy file into a mapping ``stateId -> action index``.

    After parentheses are removed, each line is split on ',' and '=' into
    ``stateId, time, action`` (i.e. lines look like ``(stateId,time)=action``).

    Args:
        strategyFile: path of the strategy file.
        allStateActionPairs: mapping ``stateId -> list of action names``;
            the action's position in that list is stored in the result.
        time_index: only lines whose time component equals this are used.

    Returns:
        dict mapping (shifted) state id to the index of the chosen action.
        Entries whose state is unknown, or whose action is not offered in
        that state, are skipped. Blank lines are ignored.
    """
    strategy = dict()
    with open(strategyFile) as strategyLines:
        for line in strategyLines:
            line = line.replace("(", "").replace(")", "").rstrip("\r\n")
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            explode = re.split(",|=", line)
            # Ids in the strategy file are shifted by +3. Ids 0-2 are reserved
            # placeholder states in parseStateValuations; presumably that is
            # the reason -- confirm.
            stateId = int(explode[0]) + 3
            if stateId < 3:
                continue
            if int(explode[1]) != time_index:
                continue
            try:
                strategy[stateId] = allStateActionPairs[stateId].index(explode[2])
            except (KeyError, ValueError):
                # KeyError: state unknown; ValueError: action not offered there.
                pass

    return strategy

def queryStrategy(strategy, stateId):
    """Return the action index the strategy picks for ``stateId``, or -1 if none.

    Only a missing key is treated as "no choice". Any other error (previously
    swallowed by a bare ``except``) now propagates.
    """
    try:
        return strategy[stateId]
    except KeyError:
        return -1


def callTempest(files, reward, bound=3):
    """Run Tempest (Storm) on the explicit model and collect filtered results.

    Eight properties are checked, restricted to states that are neither
    "failed" nor "reached":
      * min / max / avg of ``Pmax=? [ true U<=bound "failed" ]``,
      * min / max / avg of ``Pmin=? [ true U<=bound "failed" ]``,
      * min of ``Rmin=? [ C<=bound ]`` and min of ``Rmax=? [ C<=bound ]``.

    Args:
        files: space-separated file list for ``--io:explicit`` (main passes
            "<tra file> MDP_<lab file>").
        reward: suffix of the state-reward file
            ``MDP_Abstraction_interval.lab.<reward>``.
        bound: step bound used in every property.

    Returns:
        TestResult built from the parsed values, in property order. A result
        printed as a range ``[lo, hi]`` contributes two values.

    Raises:
        RuntimeError: if the tool exits non-zero, prints an unparsable
            ``Result`` line, or yields a number of values different from the
            number of TestResult fields.
    """
    from dataclasses import fields  # local: the module only imports dataclass/field

    property_str = "!(\"failed\" | \"reached\")"
    prop =  f"filter(min, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );"
    prop += f"filter(max, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );"
    prop += f"filter(avg, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );"
    prop += f"filter(min, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );"
    prop += f"filter(max, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );"
    prop += f"filter(avg, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );"
    prop += f"filter(min, Rmin=? [ C<={bound} ], {property_str}  );"
    prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str}  );"
    # NOTE(review): files/reward are interpolated unquoted into a shell command
    # (shell=True is needed for "~" expansion and for splitting `files`).
    # They must come from trusted, locally generated names.
    command = f"~/projects/tempest-devel/ranking_release/bin/storm --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' "

    # Numbers may be printed in scientific notation (e.g. 1.2e-05).
    number = r"-?\d+\.?\d*(?:[eE][-+]?\d+)?"
    rangePattern = re.compile(rf"(.*:).*\[({number}), ({number})\].*")
    valuePattern = re.compile(r"(.*:)(.*)")

    try:
        output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
    except subprocess.CalledProcessError as e:
        print(e.output)
        raise RuntimeError(f"Tempest exited with status {e.returncode}; command: {command}") from e

    results = list()
    for line in output:
        if "Result" not in line:
            continue
        range_value = rangePattern.search(line)
        if range_value:
            results.append(float(range_value.group(2)))
            results.append(float(range_value.group(3)))
            continue
        value = valuePattern.search(line)
        try:
            results.append(float(value.group(2)))
        except (AttributeError, ValueError) as e:
            raise RuntimeError(f"Could not parse Tempest result line: {line!r}") from e

    expected = len(fields(TestResult))
    if len(results) != expected:
        raise RuntimeError(f"Expected {expected} values from Tempest, got {len(results)}: {results}")
    return TestResult(*results)

def parseRanking(filename, allStates):
    """Parse a ranking file and return its states ordered by normalised ranking.

    Every non-blank line must start with an integer state id followed by
    ``:<number>`` entries. The first number is the state's ranking value; the
    remaining ones go into ``StateValue.choices``, keyed by position
    (0, 1, ...).

    Ranking values are min-max normalised to [0, 1]. If all rankings are equal
    (zero range), every state gets 0.0.

    Args:
        filename: path of the ranking file.
        allStates: mapping state id -> State; every id in the file must be
            present (KeyError otherwise).

    Returns:
        dict State -> StateValue in ascending order of normalised ranking.
        The dict is empty (never None) when the file has no entries.

    Exits the process with status 1 if the file cannot be read.
    """
    try:
        with open(filename, "r") as f:
            filecontent = f.readlines()
    except EnvironmentError:
        print(f"Ranking file '{filename}' not available. Exiting.", file=sys.stderr)
        sys.exit(1)

    state_ranking = dict()
    for line in filecontent:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line
        stateId = int(re.findall(r"^\d+", line)[0])
        values = re.findall(r":(-?\d+\.?\d*),?", line)
        ranking_value = float(values[0])
        choices = {i: float(value) for i, value in enumerate(values[1:])}
        state_ranking[allStates[stateId]] = StateValue(ranking_value, choices)

    if not state_ranking:
        # Return an empty mapping so callers can still iterate the result.
        return {}

    all_values = [value.ranking for value in state_ranking.values()]
    min_value = min(all_values)
    value_range = max(all_values) - min_value
    normalised = {
        # A zero range (all rankings equal) maps every state to 0.0.
        state: StateValue((value.ranking - min_value) / value_range if value_range else 0.0, value.choices)
        for state, value in state_ranking.items()
    }
    return dict(sorted(normalised.items(), key=lambda item: item[1].ranking))

def parseStateValuations(filename):
    """Load state valuations from ``filename`` and add placeholder states.

    Ids 0, 1 and 2 first get placeholder States whose seven fields all equal
    the id; file entries with those ids overwrite them. Each file line is
    scanned for numbers: the first becomes the integer state id, and the rest
    are converted to float and passed positionally to State. Finally, one
    more placeholder State is added for ``highest id in the file + 1`` (id 0
    if the file has no lines).

    Returns:
        dict mapping state id -> State.
    """
    numberPattern = r"(-?\d+\.?\d*),?"
    states = {dummyId: State(*([dummyId] * 7)) for dummyId in (0, 1, 2)}
    highestId = -1
    with open(filename) as valuationFile:
        for row in valuationFile:
            tokens = re.findall(numberPattern, row)
            stateId = int(tokens[0])
            states[stateId] = State(stateId, *(float(token) for token in tokens[1:]))
            highestId = max(highestId, stateId)
    sentinelId = highestId + 1
    states[sentinelId] = State(*([sentinelId] * 7))
    return states

def parseResults(allStates, maxFile="prob_results_maximize", minFile="prob_results_minimize"):
    """Pair up per-state maximum and minimum probability results.

    The two files are read in lockstep, line by line. On each line the first
    number is the state id and the second is that state's probability.
    Reading stops at the end of the shorter file.

    Args:
        allStates: unused; kept for interface compatibility.
        maxFile: file with the maximising results (default as before).
        minFile: file with the minimising results (default as before).

    Returns:
        dict state id -> ``(min, max, max - min)``.

    Raises:
        ValueError: if corresponding lines refer to different state ids.
            This replaces a bare ``assert``, which is stripped under ``python -O``.
    """
    state_to_values = dict()
    with open(maxFile) as maximizer, open(minFile) as minimizer:
        for max_line, min_line in zip(maximizer, minimizer):
            max_values = re.findall(r"(-?\d+\.?\d*),?", max_line)
            min_values = re.findall(r"(-?\d+\.?\d*),?", min_line)
            if max_values[0] != min_values[0]:
                raise ValueError(
                    f"min/max files do not match: state {min_values[0]} in {minFile} "
                    f"vs state {max_values[0]} in {maxFile}."
                )
            stateId = int(max_values[0])
            min_result = float(min_values[1])
            max_result = float(max_values[1])
            state_to_values[stateId] = (min_result, max_result, max_result - min_result)
    return state_to_values


def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration):
    """Rewrite ``filename`` in place so ``stateId`` keeps only one action.

    Transition lines that start with ``stateId`` followed by whitespace are
    kept only if their second whitespace-separated column equals
    ``chosenActionIndex``. All other lines are copied unchanged.

    Args:
        stateId: state whose transitions are trimmed.
        chosenActionIndex: action index to keep for that state.
        filename: transition file, edited in place.
        iteration: unused; kept for interface compatibility.
    """
    stateIdRegex = re.compile(rf"^{stateId}\s")  # raw string: "\s" is a regex escape
    # `with` restores sys.stdout and closes the output even if a line fails to parse.
    with fileinput.input(filename, inplace=True) as lines:
        for line in lines:
            if not stateIdRegex.match(line):
                print(line, end="")
            elif int(line.split()[1]) == chosenActionIndex:
                print(line, end="")

def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration):
    """Rewrite ``filename`` in place, keeping one action per listed state.

    For every state id in ``stateActionPairsToTrim``, that state's transition
    lines are dropped unless their second whitespace-separated column equals
    the mapped action index. Lines of all other states are copied unchanged.

    Args:
        stateActionPairsToTrim: mapping state id -> action index to keep.
            If it is empty, the file is left untouched.
        filename: transition file, edited in place.
        iteration: unused; kept for interface compatibility.
    """
    if not stateActionPairsToTrim:
        # An empty alternation compiles to "" and matches every line, which
        # then fails on int(""). Nothing to trim.
        return
    alternatives = "|".join(str(stateId) for stateId in stateActionPairsToTrim)
    stateIdsRegex = re.compile(rf"^({alternatives})\s")
    # `with` restores sys.stdout and closes the output even if a line fails to parse.
    with fileinput.input(filename, inplace=True) as lines:
        for line in lines:
            result = stateIdsRegex.match(line)
            if not result:
                print(line, end="")
                continue
            actionIndex = stateActionPairsToTrim[int(result.group(1))]
            if int(line.split()[1]) == actionIndex:
                print(line, end="")

def getTopNStates(rankedStates, n, threshold):
    """Select states to test from a ranking mapping.

    If ``n`` is 0, return every entry whose ``.ranking`` is at least
    ``threshold``. Otherwise return the last ``n`` entries in mapping order.
    These are the highest ranked when the mapping is sorted ascending, as
    parseRanking returns it.
    """
    if n == 0:
        return {state: value for state, value in rankedStates.items() if value.ranking >= threshold}
    rankedItems = list(rankedStates.items())
    return dict(rankedItems[-n:])

def getNRandomStates(rankedStates, n, testedStates, rng=None):
    """Draw up to ``n`` not-yet-tested state ids uniformly without replacement.

    Args:
        rankedStates: mapping whose keys are State objects (only ``.id`` is used).
        n: number of ids to draw.
        testedStates: iterable of state ids that must not be drawn again.
        rng: optional ``numpy.random.Generator``. If None, the global
            ``numpy.random`` state is used, as before.

    Returns:
        numpy array of state ids. If fewer than ``n`` candidates remain, all
        of them are returned in mapping order (not shuffled).
    """
    # A set gives O(1) membership tests; testedStates grows every iteration.
    alreadyTested = set(testedStates)
    notYetTestedStates = np.array([state.id for state in rankedStates.keys() if state.id not in alreadyTested])
    if len(notYetTestedStates) < n:
        return notYetTestedStates
    chooser = np.random if rng is None else rng
    return notYetTestedStates[chooser.choice(len(notYetTestedStates), size=n, replace=False)]

def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False, threshold=0.05, totalIterations=60):
    """Run the iterative, ranking-guided testing/refinement loop.

    Each iteration:
      1. model-checks this iteration's copy of ``MDP_<traFile>`` (named by
         traFileWithIteration) with callTempest and parses ``action_ranking``;
      2. archives ``action_ranking`` and both probability-result files with
         an ``_<iteration>`` suffix;
      3. selects states to test: the top ``refinementSteps`` ranked states (or
         all with ranking >= ``refinementBound`` when refinementSteps is 0),
         or random untested states when ``ablationTesting`` is set;
      4. writes the next iteration's transition file, in which every selected
         state the strategy has an action for keeps only that action;
      5. classifies the remaining states by their (min, max) probability
         estimates and prints one tab-separated progress line.

    The process exits with status 0 once no state is undecided. Otherwise the
    loop ends after ``totalIterations`` iterations.

    Args:
        traFile, labFile, straFile: model, label and strategy files.
        horizonBound: step bound passed to callTempest.
        refinementSteps, refinementBound: state-selection parameters (see 3.).
        ablationTesting: select random states instead of top-ranked ones.
        plotting, stepwisePlotting: enable visvis plotting / stepwise mode.
        threshold: classification threshold. States with min estimate above
            it count as failing; states with max estimate at most it count as
            validated. Defaults to the previously hard-coded 0.05.
        totalIterations: maximum number of iterations (previously fixed at 60).
    """
    all_states = parseStateValuations("MDP_state_valuations")
    deadlockStates, reachedStates, maxStateId = readLabels(labFile)
    stateToActions, _allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
    strategy = parseStrategy(straFile, stateToActions)

    if plotting:
        plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
        plotter.plotScenario()

    baseTraFile = "MDP_" + traFile
    # Seed iteration 0 with the same naming scheme the loop reads from.
    # The old basename-based name diverged when traFile contained a directory.
    copyFile(baseTraFile, traFileWithIteration(baseTraFile, 0))

    iteration = 0
    numTestedStates = 0
    testedStates = list()
    while iteration < totalIterations:
        print(f"{iteration:03}", end="\t")
        sys.stdout.flush()
        currentTraFile = traFileWithIteration(baseTraFile, iteration)
        nextTraFile = traFileWithIteration(baseTraFile, iteration + 1)
        testResult = callTempest(f"{currentTraFile} MDP_{labFile}", "saferew", horizonBound)
        state_ranking = parseRanking("action_ranking", all_states)
        copyFile("action_ranking", f"action_ranking_{iteration:03}")
        copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}")
        copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}")

        if not ablationTesting:
            importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound)
            statesToTest = [state.id for state in importantStates.keys()]
        else:
            statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates))
            testedStates += statesToTest

        # Every selected state that the strategy covers keeps only its chosen action.
        copyFile(currentTraFile, nextTraFile)
        stateActionPairsToTrim = dict()
        for testState in statesToTest:
            chosenActionIndex = queryStrategy(strategy, testState)
            if chosenActionIndex != -1:
                stateActionPairsToTrim[testState] = chosenActionIndex

        # results = [validated, failing, undecided]
        stateEstimates = parseResults(all_states)
        results = [0, 0, 0]
        failureStates = list()
        validatedStates = list()
        for state, estimates in stateEstimates.items():
            if state in deadlockStates or state in reachedStates:
                continue
            if estimates[0] > threshold:
                results[1] += 1
                failureStates.append(all_states[state])
            elif estimates[1] <= threshold:
                results[0] += 1
                validatedStates.append(all_states[state])
            else:
                results[2] += 1

        if stateActionPairsToTrim:  # nothing to trim -> next file stays an exact copy
            removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration)
        print(f"{numTestedStates}\t{testResult.csv(' ')}\t{results[0]}\t{results[1]}\t{results[2]}\t{sum(results)}")
        if results[2] == 0:
            sys.exit(0)
        numTestedStates += len(statesToTest)
        iteration += 1

        if plotting:
            plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6), removeMeshes=True)
            plotter.plotStates(validatedStates, coloring=(0.0,0.8,0.0,0.6))
            plotter.takeScreenshot(iteration, prefix=f"stepwise_{threshold}")

def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False):
    """Baseline: random testing of the strategy with a query budget.

    Repeatedly runs ``Simulator.runTest()`` and records the valuation of
    every test case whose verdict is ``Verdict.FAIL``. Roughly every 500
    queries, it prints ``"<queries> <failures> "`` and, when plotting, draws
    the failures, saves a screenshot and waits for Enter. The loop runs while
    the query count is <= ``maxQueries``, so the last test may overshoot the
    budget.

    Args:
        traFile, labFile, straFile: model, label and strategy files.
        bound: horizon passed to the Simulator.
        maxQueries: query budget.
        plotting: enable visvis plotting.
    """
    all_states = parseStateValuations("MDP_state_valuations")
    deadlockStates, reachedStates, maxStateId = readLabels(labFile)
    stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
    strategy = parseStrategy(straFile, stateToActions)

    if plotting:
        # NOTE(review): `stepwisePlotting` is not a parameter here; it resolves
        # to the module-level global set in the __main__ block. This therefore
        # only works when the file is run as a script.
        plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
        plotter.plotScenario()

    randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound)
    queriesSinceReport = 0
    screenshotIndex = 0  # replaces the undefined `iteration`, which raised NameError
    print("Starting with random testing.")
    numQueries = 0
    failureStates = list()
    while numQueries <= maxQueries:
        if queriesSinceReport >= 500:
            screenshotIndex += 1
            if plotting:
                plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6))
                plotter.takeScreenshot(screenshotIndex, prefix="random_testing")
                plotter.turnCamera()
                input("")
            print(f"{numQueries} {len(failureStates)} ")
            queriesSinceReport = 0
        testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest()
        queriesSinceReport += queriesForThisTestCase
        numQueries += queriesForThisTestCase
        if testResult == Verdict.FAIL:
            failureStates.append(all_states[testCase])

    print(f"{numQueries} {len(failureStates)} ")

def parseArgs():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes tra, lab, rew, stra,
        refinement_steps, refinement_bound, bound, threshold, ablation,
        random, plotting and stepwise.
    """
    ap = argparse.ArgumentParser()

    # Mandatory input files.
    for flag, helpText in (
        ('--tra', 'Path to .tra file.'),
        ('--lab', 'Path to .lab file.'),
        ('--rew', 'Path to .rew file.'),
        ('--stra', 'Path to strategy file.'),
    ):
        ap.add_argument(flag, type=str, required=True, help=helpText)

    # Exactly one refinement criterion must be supplied.
    refinementGroup = ap.add_mutually_exclusive_group(required=True)
    refinementGroup.add_argument('--refinement-steps', type=int, default=0,
                                 help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.')
    refinementGroup.add_argument('--refinement-bound', type=float, default=0,
                                 help='Threshold value for states to be tested, mutually exclusive with refinement-steps.')

    ap.add_argument('--bound', type=int, required=False, default=3,
                    help='(optional) Safety Horizon Bound, defaults to 3.')
    ap.add_argument('--threshold', type=float, required=False, default=0.05,
                    help='(optional) Safety Threshold, defaults to 0.05.')

    # Ablation testing and random testing cannot be combined.
    testingGroup = ap.add_mutually_exclusive_group()
    testingGroup.add_argument('-a', '--ablation', action='store_true',
                              help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.")
    testingGroup.add_argument('-r', '--random', type=int, default=0,
                              help='(optional) The amount of queries allowed for random testing.')

    ap.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.')
    ap.add_argument('--stepwise', action='store_true', help='(optional) Remove states before plotting the next iteration.')
    return ap.parse_args()

if __name__ == '__main__':
    args = parseArgs()

    # These module-level names are kept on purpose: randomTesting() reads the
    # global `stepwisePlotting` when plotting is enabled.
    traFile, labFile, straFile, rewFile = args.tra, args.lab, args.stra, args.rew
    ablationTesting, plotting, stepwisePlotting = args.ablation, args.plotting, args.stepwise
    maxQueriesForRandomTesting, horizonBound = args.random, args.bound
    refinementSteps, refinementBound = args.refinement_steps, args.refinement_bound
    # NOTE(review): rewFile and args.threshold are parsed but not passed on below.

    if maxQueriesForRandomTesting != 0:
        # A non-zero random-testing budget selects the random-testing baseline.
        randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting)
    else:
        # Budget 0 (the default) selects the ranking-guided refinement loop.
        main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting)