######
# LEAVE THESE IMPORTS!
######
import functions
import random
import requests
import threading
import time
from datetime import datetime
import os
import re
import json
import csv
from textual.widgets import Log

import sys
sys.path.append("/tmp")

from plotter import Plotter
from plotterStepperMakerBackend import PlotterBackend

import matplotlib
matplotlib.use("Agg")  # ensures headless backend (no GUI needed)
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors

######
# config values (you can edit these to fit your environment and use case)
######

# Serial port settings.
# These two names are not referenced anywhere else in this file; presumably
# the host application reads them from this module -- verify against caller.
SERIAL_PORT = "/dev/ttyUSB1"
# NOTE(review): 9800 is not a standard baud rate (9600 is) -- confirm this is
# intentional and not a typo.
BAUD_RATE = 9800

# Plotter settings and shared plotter state.
plotter_port="/dev/ttyUSB0"  # serial device passed to Plotter(port=...)
plotter_save_file = f"/tmp/plotter_test.json"  # JSON state loaded by setup_plotter()
plotter = None  # Global Plotter instance, created by setup_plotter()
plotterBackend = None  # Global PlotterBackend wrapping `plotter`
plotter_total_steps = None  # Step count reported by the backend after load_state()

# faultyCat variables.
# Not referenced in this file; start_testing() reads "repeat" and "delay"
# through functions.get_config_value() -- presumably these are the defaults
# for those values; verify against the host application.
REPEAT = 1
DELAY = 0

# Output files written at the end of a test run.
results_file = f"/tmp/glitching_results.csv"  # written by save_results_to_csv()
results_png = "/tmp/glitching_results.png"  # written by save_results_map_png()

###
# Each row: [name, enabled, string to match in output, function to run].
# If the match string is blank (""), no toggle is shown, just a run button.
# The function is given by name; every name below is a function defined in
# this file.
# NOTE: row index 4 ("CtSts") is used by faulty_arm_disarm() and
# start_testing() as the FaultyCat "armed" flag via
# functions.get_condition_value(4) / functions.set_condition_value(4, ...),
# so inserting or reordering rows above it requires updating those calls.
###
conditions = [
    ["TgOff", False, "", "control_device_off"],
    ["TgOn", False, "", "control_device_on"],
    ["Pass", False, "", "send_password"],
    ["RdBuf", False, "", "read_buffer"],
    ["CtSts", False, "WillNeverMatch01", "faulty_arm_disarm"],
    ["CtZap", False, "", "faulty_pulse"],
    ["Plt01", False, "", "setup_plotter"],
    ["PlSta", False, "", "plotter_move_start"],
    ["PlEnd", False, "", "plotter_move_end"],
    ["Run", False, "", "start_testing"],
]

# Per-step outcome counters; rebuilt by start_testing() as a list whose
# index 0 is a None placeholder so that list indices match 1-based step numbers.
results = []

###
# Generic Functions
###

def send_password():
    """Send the fixed test password over UART, flush the UART buffer, then log it."""
    attempt = "test"
    functions.send_uart_message(attempt)
    functions.flush_uart_buffer()
    functions.add_text("[sending password]")

def read_buffer():
    """Read the current UART buffer and echo it to the log with a "[buffer]: " prefix."""
    uart_text = functions.read_uart_buffer()
    line = "[buffer]: " + uart_text
    functions.add_text(line)

def check_buffer():
    """Classify the current UART buffer contents.

    Returns:
        0 when the buffer contains "Incorrect password: test";
        1 when it instead contains "Generated password: " followed by eight
        alphanumeric characters ending on a word boundary;
        2 in every other case.
    """
    uart_text = functions.read_uart_buffer()

    # The rejection message takes priority over everything else.
    if "Incorrect password: test" in uart_text:
        return 0

    generated = re.search(r"Generated password: [A-Za-z0-9]{8}\b", uart_text)
    return 1 if generated else 2

###
# ESP32 Power Supply Relay
###

def control_device_on():
    """Log "[turning on device]" and request the relay's /off endpoint.

    NOTE(review): the "on" action calls /off -- presumably the relay is wired
    so that releasing it powers the target; confirm against the hardware.

    Returns:
        The HTTP response body, or an "Error: ..." string if the request
        raised any exception (5 second timeout).
    """
    functions.add_text("[turning on device]")
    try:
        response = requests.get("http://192.168.0.122/off", timeout=5)
        return response.text
    except Exception as exc:
        return f"Error: {exc}"
      
def control_device_off():
    """Log "[turning off device]" and request the relay's /on endpoint.

    NOTE(review): the "off" action calls /on -- presumably the relay is wired
    so that energising it cuts power to the target; confirm against the hardware.

    Returns:
        The HTTP response body, or an "Error: ..." string if the request
        raised any exception (5 second timeout).
    """
    functions.add_text("[turning off device]")
    try:
        response = requests.get("http://192.168.0.122/on", timeout=5)
        return response.text
    except Exception as exc:
        return f"Error: {exc}"

###
# faultyCat
###

def faulty_arm_disarm():
    """Toggle the FaultyCat between armed and disarmed.

    Condition 4 holds the armed flag. When it is True the device is disarmed
    and the flag cleared; otherwise the device is connected and armed, and
    the flag is set only if both calls report success.
    """
    currently_armed = functions.get_condition_value(4)

    if currently_armed is not True:
        functions.add_text("[FaultyCat arming]")
        if functions.faulty_connect() and functions.faulty_arm():
            functions.set_condition_value(4, True)
        return

    functions.add_text("[FaultyCat disarming]")
    functions.faulty_disarm()
    functions.set_condition_value(4, False)

def faulty_pulse():
    """Fire one FaultyCat pulse and schedule a disarm one second later.

    The disarm runs on a daemon thread so the caller is not blocked.
    """
    functions.faulty_send_pulse()
    functions.add_text("[FaultyCat sending pulse]")

    # Start a background thread that waits 1 second before disarming
    def delayed_disarm():
        time.sleep(1)
        # faulty_arm_disarm() is a toggle: calling it while the armed flag
        # (condition 4) is not set would ARM the device instead of disarming
        # it, so only call it when the device is currently flagged as armed.
        if functions.get_condition_value(4) is True:
            faulty_arm_disarm()

    threading.Thread(target=delayed_disarm, daemon=True).start()

def delayed_faulty_pulse(delay_ms):
    """Wait ``delay_ms`` milliseconds, then fire a FaultyCat pulse.

    Args:
        delay_ms: Delay in milliseconds. Numeric strings are accepted and
            negative values are treated as zero.

    This normally runs on a background thread, where an uncaught exception
    would go unnoticed, so an unusable delay is reported through the log and
    no pulse is sent.
    """
    try:
        delay_s = float(delay_ms) / 1000.0  # convert milliseconds to seconds
    except (TypeError, ValueError):
        functions.add_text(f"[error] Invalid pulse delay: {delay_ms!r}")
        return

    # time.sleep() raises ValueError for negative durations.
    time.sleep(max(0.0, delay_s))
    faulty_pulse()

###
# plotter 
###

def setup_plotter():
    """Create the global plotter objects and load the saved step state.

    Builds a ``Plotter`` on ``plotter_port`` and a ``PlotterBackend`` around
    it. If ``plotter_save_file`` exists, its JSON content is loaded into the
    backend, ``plotter_total_steps`` is updated and a short summary is logged.
    If the file is missing, an error is logged and ``plotter_total_steps`` is
    left as ``None`` so callers can detect that no steps are available.
    """
    global plotter, plotterBackend, plotter_total_steps
    plotter = Plotter(port=plotter_port)
    plotterBackend = PlotterBackend(plotter)

    # A fresh backend has no loaded state yet; drop any count left over from
    # a previous setup so it cannot be mistaken for the current one.
    plotter_total_steps = None

    if not os.path.exists(plotter_save_file):
        functions.add_text(f"[ERROR] Plotter state file not found: {plotter_save_file}")
        return

    # Load state
    with open(plotter_save_file, "r") as f:
        state = json.load(f)

    plotterBackend.load_state(state)
    plotter_total_steps = plotterBackend.get_total_steps()
    functions.add_text(f"[total steps] - {plotter_total_steps}")
    functions.add_text(f"[first place] - {plotterBackend.get_step_coords(1)}")
    # Example of a logged step: {'x': 2.0, 'y': 2.0, 'z': 40.0}
    functions.add_text(f"[last place] - {plotterBackend.get_step_coords(plotter_total_steps)}")
    functions.add_text(f"[cur place] - {plotterBackend.get_position()}")

def plotter_move_start():
    """Move the plotter to step 1 and apply that step's "z" as spindle speed.

    Logs an error and returns if the plotter has not been set up. Any
    exception raised while moving is caught and logged.
    """
    if not plotter:
        functions.add_text("[ERROR] Plotter not initialised.")
        return

    try:
        here = plotterBackend.get_position()
        first = plotterBackend.get_step_coords(1)
        functions.add_text(f"[moving to] - {first}")
        if not first:
            return

        # The plotter takes relative moves, so send the offset from the
        # current position rather than the absolute target.
        offset_x = float(first["x"]) - float(here["x"])
        offset_y = float(first["y"]) - float(here["y"])
        if any(abs(delta) > 1e-9 for delta in (offset_x, offset_y)):
            plotter.move(x=offset_x, y=offset_y, feed=1000)
        plotter.set_spindle(speed=first.get("z", 0.0))
    except Exception as exc:
        functions.add_text(f"[ERROR] plotter_move_start failed: {exc}")

def plotter_move_end():
    """Move the plotter to the last step and apply its "z" as spindle speed.

    Logs an error and returns if the plotter has not been set up. Any
    exception raised while moving is caught and logged.
    """
    if not plotter:
        functions.add_text("[ERROR] Plotter not initialised.")
        return

    try:
        last_index = plotterBackend.get_total_steps()
        last = plotterBackend.get_step_coords(last_index)
        functions.add_text(f"[moving to] - {last}")
        if not last:
            return

        # The plotter takes relative moves, so send the offset from the
        # current position rather than the absolute target.
        here = plotterBackend.get_position()
        offset_x = float(last["x"]) - float(here["x"])
        offset_y = float(last["y"]) - float(here["y"])
        if any(abs(delta) > 1e-9 for delta in (offset_x, offset_y)):
            plotter.move(x=offset_x, y=offset_y, feed=1000)
        plotter.set_spindle(speed=last.get("z", 0.0))
    except Exception as exc:
        functions.add_text(f"[ERROR] plotter_move_end failed: {exc}")

def plotter_move_loc(position=1):
    """Move the plotter to the given step and apply its "z" as spindle speed.

    Args:
        position: Step number passed to the backend's ``get_step_coords``.

    Logs an error and returns if the plotter has not been set up. Any
    exception raised while moving is caught and logged.
    """
    if not plotter:
        functions.add_text("[ERROR] Plotter not initialised.")
        return

    try:
        destination = plotterBackend.get_step_coords(position)
        functions.add_text(f"[moving to] - {destination}")
        if not destination:
            return

        # The plotter takes relative moves, so send the offset from the
        # current position rather than the absolute target.
        here = plotterBackend.get_position()
        offset_x = float(destination["x"]) - float(here["x"])
        offset_y = float(destination["y"]) - float(here["y"])
        if any(abs(delta) > 1e-9 for delta in (offset_x, offset_y)):
            plotter.move(x=offset_x, y=offset_y, feed=1000)
        plotter.set_spindle(speed=destination.get("z", 0.0))
    except Exception as exc:
        functions.add_text(f"[ERROR] plotter_move_loc failed: {exc}")

###
# main program
###
def start_testing():
    """Run the full glitching campaign on a background daemon thread.

    The worker sets up the plotter, then for each of the configured "repeat"
    loops visits every plotter step: it moves to the step, power-cycles the
    target, arms the FaultyCat, sends the password while firing a pulse after
    the configured "delay" (milliseconds), and classifies the UART output
    with check_buffer(). Outcomes are accumulated in the global ``results``
    list (index 0 is a placeholder so indices match step numbers).

    When the loops finish -- or are aborted by an exception -- the elapsed
    time is logged and the CSV, text, ASCII-map and PNG reports are produced
    from whatever was collected, so a failure late in a long run does not
    discard the data gathered so far.
    """
    def worker():
        global plotter_total_steps, results
        try:
            setup_plotter()
            loop = int(functions.get_config_value("repeat"))
            time.sleep(1)

            # setup_plotter() only sets the step count when a saved state was
            # loaded; without it there is nothing to iterate over.
            if not plotter_total_steps:
                functions.add_text("[error] No plotter steps loaded; cannot start testing.")
                return

            # Initialise results array with a dummy at index 0
            results = [None] + [
                {"nothing": 0, "crash": 0, "glitch": 0, "x": 0.0, "y": 0.0, "z": 0.0}
                for _ in range(plotter_total_steps)
            ]

            start_time = datetime.now()
            functions.add_text(f"[start time] {start_time.strftime('%d/%m/%Y %H:%M')}")

            try:
                for i in range(loop):
                    for currentStep in range(1, plotter_total_steps + 1):
                        functions.add_text(f"[stats] loop {i +1}/{loop}, place {currentStep}/{plotter_total_steps}")

                        # Get target coordinates
                        target = plotterBackend.get_step_coords(currentStep)

                        # Store coordinates in results
                        results[currentStep]["x"] = target.get("x", 0.0)
                        results[currentStep]["y"] = target.get("y", 0.0)
                        results[currentStep]["z"] = target.get("z", 0.0)

                        plotter_move_loc(currentStep)

                        # Power-cycle the target so every attempt starts
                        # from a fresh boot.
                        control_device_off()
                        functions.flush_uart_buffer()
                        time.sleep(1)

                        control_device_on()
                        time.sleep(1)

                        functions.add_text("[FaultyCat arming]")
                        if functions.faulty_connect() and functions.faulty_arm():
                            functions.set_condition_value(4, True)

                        time.sleep(1)
                        # Re-read every step so the delay can be changed
                        # mid-run; convert it the same way "repeat" is
                        # converted so a string value cannot break the
                        # pulse thread.
                        curDelay = float(functions.get_config_value("delay"))

                        functions.flush_uart_buffer()
                        threading.Thread(target=send_password, daemon=True).start()
                        threading.Thread(target=delayed_faulty_pulse, args=(curDelay,), daemon=True).start()

                        time.sleep(2)
                        result = check_buffer()
                        functions.add_text(f"[result] {result}")

                        # Update result counters
                        if result == 0:
                            results[currentStep]["nothing"] += 1
                        elif result == 1:
                            results[currentStep]["crash"] += 1
                        else:
                            results[currentStep]["glitch"] += 1
            except Exception as e:
                # Keep going to the reporting stage with the partial data.
                functions.add_text(f"[error] Exception in start_testing: {e}")
                functions.add_text("[error] Run aborted early; reporting partial results.")

            end_time = datetime.now()
            functions.add_text(f"[end time] {end_time.strftime('%d/%m/%Y %H:%M')}")

            elapsed = end_time - start_time
            # Format elapsed time including days
            total_seconds = int(elapsed.total_seconds())
            days, remainder = divmod(total_seconds, 86400)  # 86400 seconds in a day
            hours, remainder = divmod(remainder, 3600)
            minutes, seconds = divmod(remainder, 60)

            if days > 0:
                functions.add_text(f"[time elapsed] {days}d {hours:02d}:{minutes:02d}:{seconds:02d}")
            else:
                functions.add_text(f"[time elapsed] {hours:02d}:{minutes:02d}:{seconds:02d}")

            save_results_to_csv()
            print_results()
            generate_ascii_map()
            # save_results_map_png() also calls find_best_glitch_spot().
            save_results_map_png()

        except Exception as e:
            functions.add_text(f"[error] Exception in start_testing: {e}")

    threading.Thread(target=worker, daemon=True).start()

###
# results manipulation
###
def save_results_to_csv():
    """Write the per-step result counters and coordinates to ``results_file``.

    One row is written per step (index 0 of ``results`` is a placeholder and
    is skipped). Logs an error instead of writing when there are no results,
    and logs any exception raised while writing.
    """
    if not results or len(results) <= 1:
        functions.add_text("[error] No results to save.")
        return

    columns = ["Step", "Nothing", "Crash", "Glitch", "X", "Y", "Z"]
    try:
        with open(results_file, mode="w", newline="") as handle:
            out = csv.DictWriter(handle, fieldnames=columns)
            out.writeheader()
            for step_no, entry in enumerate(results[1:], start=1):
                row = {
                    "Step": step_no,
                    "Nothing": entry["nothing"],
                    "Crash": entry["crash"],
                    "Glitch": entry["glitch"],
                    "X": entry["x"],
                    "Y": entry["y"],
                    "Z": entry["z"],
                }
                out.writerow(row)
        functions.add_text(f"[results saved] {results_file}")
    except Exception as exc:
        functions.add_text(f"[error] Failed to save results: {exc}")

def print_results():
    """Log one summary line per step with its counters and coordinates."""
    # Index 0 of `results` is a placeholder, so numbering starts at 1.
    for step_no, entry in enumerate(results[1:], start=1):
        counters = f"[step {step_no}] nothing={entry['nothing']} crash={entry['crash']} glitch={entry['glitch']} "
        location = f"x={entry['x']} y={entry['y']} z={entry['z']}"
        functions.add_text(counters + location)

def generate_ascii_map():
    """Log an ASCII grid of the results, one character per scanned position.

    Columns follow ascending x and rows follow descending y (largest y on
    the first line). Each cell shows:
        digit  glitch count, capped at 9
        "-"    at least one crash and no glitch
        "."    only "nothing" outcomes
        " "    no outcome recorded, or no step at that grid position
    """
    if not results or len(results) <= 1:
        functions.add_text("[error] No results to map.")
        return

    def _coord_key(value):
        # Keep fractional coordinates distinct (int() would merge e.g. 2.0
        # and 2.5 into one cell) while absorbing float noise.
        return round(float(value), 6)

    entries = results[1:]  # index 0 is a placeholder

    # Collect unique sorted coordinates
    xs = sorted({_coord_key(entry["x"]) for entry in entries})
    ys = sorted({_coord_key(entry["y"]) for entry in entries}, reverse=True)  # top to bottom

    # Map coordinates to grid indices
    x_to_idx = {x: idx for idx, x in enumerate(xs)}
    y_to_idx = {y: idx for idx, y in enumerate(ys)}

    # Pre-fill with spaces so positions that have no step still occupy a
    # column; empty strings would shift the rest of the row to the left.
    grid = [[" "] * len(xs) for _ in ys]

    # Fill grid
    for entry in entries:
        if entry["glitch"] > 0:
            symbol = str(min(entry["glitch"], 9))
        elif entry["crash"] > 0:
            symbol = "-"
        elif entry["nothing"] > 0:
            symbol = "."
        else:
            symbol = " "

        row = y_to_idx[_coord_key(entry["y"])]
        col = x_to_idx[_coord_key(entry["x"])]
        grid[row][col] = symbol

    # Convert grid to string
    ascii_map = "\n".join("".join(row) for row in grid)
    functions.add_text("[ASCII map]")
    functions.add_text(ascii_map)

def find_best_glitch_spot():
    """Return the step number most likely to cause a glitch, or None.

    Only steps with at least one recorded glitch are candidates. Each
    candidate is scored over its 3x3 grid neighbourhood:
    - glitches add ``count * weight`` (centre 4, orthogonal 2, diagonal 1)
    - crashes subtract ``count * 2`` (not scaled by the neighbour weight)
    - "nothing" results subtract ``count * 1`` (not scaled by the weight)
    - a neighbour with no step (edge / gap) subtracts ``1 * weight``

    The highest score wins; ties go to the step with more glitches at its
    own position, then to the lowest step number. The winner is logged.

    Returns:
        The 1-based step number, or None if there are no results or no step
        recorded a glitch.
    """
    if not results or len(results) <= 1:
        functions.add_text("[error] No results available.")
        return None

    entries = results[1:]  # index 0 is a placeholder

    # Collect unique coordinates
    xs = sorted({float(entry["x"]) for entry in entries})
    ys = sorted({float(entry["y"]) for entry in entries})

    # Compute minimal spacing
    def _min_spacing(values):
        gaps = [round(b - a, 8) for a, b in zip(values, values[1:])]
        gaps = [gap for gap in gaps if gap > 1e-8]
        return min(gaps) if gaps else 1.0

    base = min(_min_spacing(xs), _min_spacing(ys))
    x_origin, y_origin = xs[0], ys[0]

    def _cell(entry):
        # Grid index = offset from the smallest coordinate, in units of the
        # smallest spacing. Dividing by the spacing (rather than multiplying
        # by an integer-rounded 1/spacing) keeps this correct for spacings
        # of 2.0 or more, where that rounded factor becomes 0 and every
        # point lands on the same cell.
        return (
            int(round((float(entry["x"]) - x_origin) / base)),
            int(round((float(entry["y"]) - y_origin) / base)),
        )

    # Map coordinates to step index
    coord_to_step = {_cell(entry): step for step, entry in enumerate(entries, start=1)}

    # Neighbourhood weights
    centre_w = 4
    orth_w = 2
    diag_w = 1

    # Penalties
    crash_penalty = 2
    nothing_penalty = 1

    best_step = None
    best_score = None

    for step, entry in enumerate(entries, start=1):
        centre_glitches = int(entry["glitch"])
        if centre_glitches == 0:
            continue  # only consider glitch spots

        gx, gy = _cell(entry)
        score = 0

        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                if dx == 0 and dy == 0:
                    w = centre_w
                elif dx == 0 or dy == 0:
                    w = orth_w
                else:
                    w = diag_w

                idx = coord_to_step.get((gx + dx, gy + dy))
                if idx is None:
                    # Edge: treat as "nothing"
                    score -= nothing_penalty * w
                    continue

                neighbour = results[idx]
                score += int(neighbour["glitch"]) * w
                score -= int(neighbour["crash"]) * crash_penalty
                score -= int(neighbour["nothing"]) * nothing_penalty

        # Steps are visited in ascending order, so keeping the incumbent on
        # a full tie already prefers the lowest step number.
        if (
            best_step is None
            or score > best_score
            or (score == best_score and centre_glitches > int(results[best_step]["glitch"]))
        ):
            best_score = score
            best_step = step

    if best_step is not None:
        functions.add_text(
            f"[best glitch] position {best_step} glitches={results[best_step]['glitch']} "
            f"crashes={results[best_step]['crash']} nothing={results[best_step]['nothing']} "
            f"x={results[best_step]['x']} y={results[best_step]['y']} z={results[best_step]['z']} score={best_score}"
        )
        return best_step

    functions.add_text("[best glitch] No glitch-prone positions found.")
    return None

def save_results_map_png():
    """Render the results as a colour grid and save it to ``results_png``.

    One unit square is drawn per step, laid out like the ASCII map (ascending
    x left to right, largest y at the top):
        green shades  glitch count (capped at 9; brighter means more)
        red           at least one crash and no glitch
        grey          only "nothing" outcomes
        black         no outcome recorded
    The step returned by find_best_glitch_spot() is marked with a black X.

    Logs an error when there are no results, and logs any exception raised
    while drawing or saving. The figure is always closed afterwards.
    """
    if not results or len(results) <= 1:
        functions.add_text("[error] No results to plot.")
        return

    def _coord_key(value):
        # Keep fractional coordinates distinct (int() would merge e.g. 2.0
        # and 2.5 into one cell) while absorbing float noise.
        return round(float(value), 6)

    fig = None
    try:
        entries = results[1:]  # index 0 is a placeholder

        # Collect sorted unique coordinates
        xs = sorted({_coord_key(entry["x"]) for entry in entries})
        ys = sorted({_coord_key(entry["y"]) for entry in entries}, reverse=True)

        width = len(xs)
        height = len(ys)

        x_to_idx = {x: idx for idx, x in enumerate(xs)}
        y_to_idx = {y: idx for idx, y in enumerate(ys)}

        fig, ax = plt.subplots(figsize=(width, height), dpi=100)
        ax.set_xlim(0, width)
        ax.set_ylim(0, height)
        ax.set_aspect("equal")
        ax.axis("off")

        # Invert the y-axis so row index 0 (the largest y) is drawn at the
        # top of the image, matching the row order of the ASCII map.
        ax.invert_yaxis()

        # Colour map for glitches (from #025218 to #6eff96)
        glitch_cmap = mcolors.LinearSegmentedColormap.from_list("glitch_cmap", ["#025218", "#6eff96"])

        # Draw squares
        for entry in entries:
            gx = x_to_idx[_coord_key(entry["x"])]
            gy = y_to_idx[_coord_key(entry["y"])]

            if entry["glitch"] > 0:
                level = min(entry["glitch"], 9) / 9.0
                color = glitch_cmap(level)
            elif entry["crash"] > 0:
                color = "#ff0000"  # red
            elif entry["nothing"] > 0:
                color = "#3d3d3d"  # grey
            else:
                color = "#000000"

            rect = patches.Rectangle(
                (gx, gy), 1, 1, facecolor=color, edgecolor="black", linewidth=0.2
            )
            ax.add_patch(rect)

        # Highlight the best glitch spot with X
        best_step = find_best_glitch_spot()
        if best_step:
            bx = x_to_idx[_coord_key(results[best_step]["x"])]
            by = y_to_idx[_coord_key(results[best_step]["y"])]

            ax.plot([bx, bx + 1], [by, by + 1], color="black", linewidth=1.5)
            ax.plot([bx, bx + 1], [by + 1, by], color="black", linewidth=1.5)

        # Operate on this figure explicitly instead of pyplot's implicit
        # "current figure".
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(results_png, bbox_inches="tight", pad_inches=0)

        functions.add_text(f"[results map saved] {results_png}")

    except Exception as e:
        functions.add_text(f"[error] Failed to generate map: {e}")
    finally:
        # Release the figure even when drawing or saving failed; pyplot
        # keeps every unclosed figure alive.
        if fig is not None:
            plt.close(fig)