import serial
import time
import sys
import os
from tabulate import tabulate
from scope import Scope
from collections import deque
# Serial port settings
SERIAL_PORT = "/dev/ttyUSB0"  # serial device node opened by connect_serial
BAUD_RATE = 115200
TIMEOUT = 2  # serial read timeout in seconds (passed to serial.Serial)
# Glitch configuration limits.
# glitch_len starts at LOWER_GLITCH_LEN and is stepped by INCREMENT_LEN toward
# UPPER_GLITCH_LEN; execute_scope_script writes it to s.glitch.repeat.
# NOTE: read_serial also reassigns UPPER_GLITCH_LEN at runtime, so it is
# mutable state rather than a true constant.
INCREMENT_LEN = 1
LOWER_GLITCH_LEN = 41
UPPER_GLITCH_LEN = 60
# trigger_repeats = number of s.trigger() calls per glitch attempt.
INCREMENT_REPEAT = 1
LOWER_GLITCH_REPEAT = 1
UPPER_GLITCH_REPEAT = 30
# delay_time is written to s.glitch.ext_offset; its unit is defined by the
# Scope API (NOTE(review): confirm - clock cycles?).
INCREMENT_DELAY = 5
LOWER_DELAY_TIME = 0
UPPER_DELAY_TIME = 30
# Rolling log of the 20 most recent status messages (oldest dropped first).
message_history = deque(maxlen=20)
def restartDeviceAndChall():
    """Run the IO-0 sequence on the scope, then the challenge restart sequence.

    Queues ``s.io.add(0, 0)`` followed by ``s.io.add(0, 1)`` (each with
    ``delay=20000000``), uploads and triggers them, waits 5 s, and then runs
    ``restartChall`` for the second IO sequence.

    NOTE(review): the meaning of the ``s.io.add`` arguments (pin / level?) and
    the unit of ``delay`` come from the Scope API - confirm before changing.
    Relies on the module-level global ``s``.
    """
    for level in (0, 1):
        s.io.add(0, level, delay=20000000)
    s.io.upload()
    s.trigger()
    time.sleep(5)

    restartChall()
def restartChall():
    """Queue, upload and trigger the challenge-restart IO sequence, then wait 5 s.

    The sequence is ``s.io.add(1, 1)`` followed by ``s.io.add(1, 0)``, each
    with ``delay=30000000``.

    NOTE(review): argument meanings and delay units are defined by the Scope
    API - confirm. Relies on the module-level global ``s``.
    """
    for level in (1, 0):
        s.io.add(1, level, delay=30000000)
    s.io.upload()
    s.trigger()
    time.sleep(5)
def clear_console():
    """Clear the terminal using the platform's shell command (``cls`` on Windows, ``clear`` elsewhere)."""
    command = "cls" if os.name == "nt" else "clear"
    os.system(command)
def store_message(message):
    """Append ``message`` to the bounded ``message_history`` log.

    Once the deque is full, appending silently drops the oldest entry.
    """
    message_history.append(message)
def print_info(message):
    """Record ``message`` in the log history with an ``[INFO]`` prefix.

    Despite the name, nothing is printed here; the history is only shown when
    it is rendered later.
    """
    tag = "[INFO]"
    store_message(f"{tag} {message}")
def print_warning(message):
    """Record ``message`` in the log history with a ``[WARNING]`` prefix (not printed immediately)."""
    tag = "[WARNING]"
    store_message(f"{tag} {message}")
def print_error(message):
    """Record ``message`` in the log history with an ``[ERROR]`` prefix (not printed immediately)."""
    tag = "[ERROR]"
    store_message(f"{tag} {message}")
def print_last_x_messages(x, history=None):
    """Print the newest ``x`` log messages, oldest first, one per line.

    Args:
        x: Maximum number of messages to print. If fewer are stored, all of
            them are printed; ``x <= 0`` prints nothing.
        history: Optional iterable of messages to read from. Defaults to the
            module-level ``message_history`` deque.
    """
    if history is None:
        history = message_history
    if x <= 0:
        # Guard explicitly: list(...)[-0:] is the *whole* list, and a negative
        # x would turn the slice into "skip the first |x| entries".
        return
    for msg in list(history)[-x:]:
        print(msg)
def format_elapsed_time(seconds):
    """Render a duration given in seconds as ``"<d>d <h>h <m>m <s>s"``.

    Every component is truncated with ``int()``, so fractional seconds are
    dropped (e.g. 59.9 -> ``"0d 0h 0m 59s"``).
    """
    seconds_per_day = 24 * 3600
    components = (
        seconds // seconds_per_day,          # whole days
        seconds % seconds_per_day // 3600,   # hours within the day
        seconds % 3600 // 60,                # minutes within the hour
        seconds % 60,                        # seconds within the minute
    )
    day_part, hour_part, minute_part, second_part = map(int, components)
    return f"{day_part}d {hour_part}h {minute_part}m {second_part}s"
def print_table(glitch_len, trigger_repeats, delay_time, elapsed_time):
    """Redraw the screen: clear it, show the banner, then a one-row status table.

    Each sweep column is rendered as ``"<current> / <upper limit>"`` using the
    module-level upper limits; the last column is the pre-formatted
    ``elapsed_time`` string.
    """
    progress = (
        (glitch_len, UPPER_GLITCH_LEN),
        (trigger_repeats, UPPER_GLITCH_REPEAT),
        (delay_time, UPPER_DELAY_TIME),
    )
    row = [f"{current} / {limit}" for current, limit in progress]
    row.append(elapsed_time)
    column_titles = ["Glitch Len", "Repeats", "Delay", "Elapsed Time"]
    clear_console()
    print_banner()
    print(tabulate([row], headers=column_titles, tablefmt="fancy_grid"))
def connect_serial():
    """Open ``SERIAL_PORT`` and read one version line from the target.

    Returns:
        The open ``serial.Serial`` instance, configured with a read timeout of
        ``TIMEOUT`` seconds.

    Exits the process with status 1 if pyserial raises ``SerialException``
    while opening the port or reading the version line.
    """
    try:
        ser = serial.Serial(SERIAL_PORT, BAUD_RATE, timeout=TIMEOUT)
        print_info("Connected to serial port.")

        # Give the target a moment, then discard anything already buffered so
        # the next line read is fresh. reset_input_buffer() is the pyserial
        # 3.x name; flushInput() is its deprecated alias.
        time.sleep(1)
        ser.reset_input_buffer()
        version_info = ser.readline().decode("utf-8", errors="ignore").strip()
        if version_info:
            print_info(f"Connected to version: {version_info}")
        else:
            print_warning("No version information received. Device might be unresponsive.")
        return ser
    except serial.SerialException as e:
        message = f"Could not open serial port {SERIAL_PORT}: {e}"
        print_error(message)
        # print_error only buffers into the message history, which nothing has
        # rendered yet; print directly so the reason for exit(1) is visible.
        print(f"[ERROR] {message}", file=sys.stderr)
        sys.exit(1)
def read_serial():
    """Run the glitch-parameter sweep, driven by lines received from the target.

    Opens the port via ``connect_serial`` and loops until a flag is seen. Each
    received line is logged and then classified (case-insensitively):

    * contains ``"ctf"``: treated as the flag. The status table and recent
      messages are printed and the program exits via ``sys.exit()``.
    * contains ``"starting challenge"``: the sweep position is nudged. If
      ``glitch_len + trigger_repeats`` is even, one repeat is removed and the
      length grows. Otherwise ``UPPER_GLITCH_LEN`` is clamped to the current
      length, the length shrinks and the repeats grow.
    * contains ``"hold"``: only logged (the restartChall call is disabled).
    * anything else: one glitch attempt is fired with the current parameters,
      then the sweep advances. ``delay_time`` moves first; each time it wraps,
      either ``glitch_len`` or ``trigger_repeats`` is stepped.

    The sweep position lives in the module-level globals ``glitch_len``,
    ``trigger_repeats`` and ``delay_time``. ``UPPER_GLITCH_LEN`` is also
    rewritten at runtime. The serial port is closed whenever the loop exits,
    including via ``sys.exit()`` or Ctrl+C.
    """
    global glitch_len, trigger_repeats, delay_time
    # UPPER_GLITCH_LEN is the only limit reassigned below; the other limits and
    # increments are only read, so they need no global declaration.
    global UPPER_GLITCH_LEN

    glitch_len = LOWER_GLITCH_LEN
    trigger_repeats = LOWER_GLITCH_REPEAT
    delay_time = LOWER_DELAY_TIME
    start_time = time.time()
    last_restart_time = time.time()
    restart_interval = 600  # 600 secs = 10 min
    ser = connect_serial()
    # NOTE: the restart call below is disabled even though this message says
    # a restart happens.
    print_info("Restarting device and challenge")
    # restartDeviceAndChall()

    try:
        while True:
            # The port was opened with a read timeout, so readline() returns
            # after a '\n' or once a read times out (possibly with a partial
            # line or nothing). This replaces a busy-wait that spun the CPU and
            # decoded byte-by-byte, which dropped multi-byte UTF-8 characters.
            line = ser.readline().decode("utf-8", errors="ignore")
            if not line:
                continue  # nothing arrived within the timeout
            line = line.strip()
            print_info(f"Received data: {line}")
            line_lower = line.lower()
            if "ctf" in line_lower:
                print_info(f"Flag: {line}")
                print_warning("Received 'ctf', exiting...")
                # Compute it here: elapsed_time is otherwise only assigned in
                # the glitch branch, so a flag arriving before the first glitch
                # attempt raised UnboundLocalError and the flag was never shown.
                elapsed_time = format_elapsed_time(time.time() - start_time)
                print_table(glitch_len, trigger_repeats, delay_time, elapsed_time)
                print_last_x_messages(10)
                sys.exit()

            elif "starting challenge" in line_lower:
                print_info("Detected 'Starting Challenge'. Resetting glitch repeat count.")
                if (glitch_len + trigger_repeats) % 2 == 0:
                    # UPPER_GLITCH_REPEAT = trigger_repeats
                    trigger_repeats -= INCREMENT_REPEAT
                    glitch_len += INCREMENT_LEN
                else:
                    UPPER_GLITCH_LEN = glitch_len
                    glitch_len -= INCREMENT_LEN
                    trigger_repeats += INCREMENT_REPEAT
            elif "hold" in line_lower:
                print_info("Detected 'Hold'. Restarting Challenge.")
                # restartChall()
            else:
                execute_scope_script(glitch_len, trigger_repeats, delay_time)
                # 1. Increment delay time first
                if delay_time < UPPER_DELAY_TIME:
                    delay_time += INCREMENT_DELAY
                else:
                    delay_time = LOWER_DELAY_TIME  # Reset delay time
                    # 2. Alternate between glitch_len and trigger_repeats when delay resets
                    if (glitch_len + trigger_repeats) % 2 == 0:
                        if glitch_len < UPPER_GLITCH_LEN:
                            glitch_len += INCREMENT_LEN
                            print_info(f"Incrementing glitch length: {glitch_len}")
                        elif glitch_len > UPPER_GLITCH_LEN:
                            # NOTE(review): this raises the limit instead of
                            # moving glitch_len - confirm this is intended.
                            UPPER_GLITCH_LEN += INCREMENT_LEN
                        else:
                            glitch_len = LOWER_GLITCH_LEN  # Reset glitch length
                            trigger_repeats += INCREMENT_REPEAT  # Increment trigger repeats
                            print_info(f"Glitch length reset. Incrementing trigger repeats: {trigger_repeats}")
                    else:
                        if trigger_repeats < UPPER_GLITCH_REPEAT:
                            trigger_repeats += INCREMENT_REPEAT
                            print_info(f"Incrementing trigger repeats: {trigger_repeats}")
                        else:
                            if glitch_len == UPPER_GLITCH_LEN:
                                trigger_repeats = LOWER_GLITCH_REPEAT  # Reset trigger repeats
                            else:
                                trigger_repeats = LOWER_GLITCH_REPEAT
                                glitch_len += INCREMENT_LEN  # Increment glitch length
                            # NOTE(review): logged even when glitch_len was not changed.
                            print_info(f"Trigger repeats reset. Incrementing glitch length: {glitch_len}")
                elapsed_time = format_elapsed_time(time.time() - start_time)
                print_table(glitch_len, trigger_repeats, delay_time, elapsed_time)
                # Show the 10 most recent log messages below the table
                print_last_x_messages(10)
                if time.time() - last_restart_time >= restart_interval:
                    # NOTE: only the upper length limit changes; the device
                    # restart call is disabled.
                    print_info("10 min mark. Setting glitch length: 80 and restarting device")
                    UPPER_GLITCH_LEN = 80
                    # restartDeviceAndChall()
                    last_restart_time = time.time()
    finally:
        # Also runs on sys.exit() / KeyboardInterrupt, releasing the tty.
        ser.close()
def execute_scope_script(glitch_len, trigger_repeats, delay):
    """Configure the glitch on the global scope ``s`` and trigger it repeatedly.

    Args:
        glitch_len: Value written to ``s.glitch.repeat``.
        trigger_repeats: Number of ``s.trigger()`` calls to issue.
        delay: Value written to ``s.glitch.ext_offset`` (unit defined by the
            Scope API - NOTE(review): confirm).
    """
    s.glitch.repeat = glitch_len
    s.glitch.ext_offset = delay
    for _attempt in range(trigger_repeats):
        s.trigger()
def print_banner():
    """Print the four-line ASCII-art title shown above the status table."""
    banner_rows = (
        "   ___ _ _ _      _                          _   _    ",
        "  / __| (_) |_ __| |_ ___ ___ ___ _ __  __ _| |_(_)__ ",
        " | (_ | | |  _/ _| ' \\___/ _ \\___| '  \\/ _` |  _| / _|",
        "  \\___|_|_|\\__\\__|_||_|  \\___/   |_|_|_\\__,_|\\__|_\\__|",
    )
    for row in banner_rows:
        print(row)
if __name__ == "__main__":
    # Buffered into message_history only; it appears on screen once the
    # history is printed by print_last_x_messages.
    print_info("Starting Program")
    # `s` must be bound at module scope: restartDeviceAndChall, restartChall
    # and execute_scope_script all reference the global name `s`.
    s = Scope()
    read_serial()