#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "matplotlib",
#     "numpy",
# ]
# ///

import subprocess
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def run_benchmark(size: int, benchmark_file: str) -> float:
    """Run one benchmark program and return its wall-clock time in seconds.

    The size is written to the program's standard input as a decimal number
    followed by a newline, i.e. the same bytes ``echo <size>`` produces.

    Args:
        size: Problem size handed to the benchmark on stdin.
        benchmark_file: Path to the benchmark. Files ending in ``.roc`` are
            launched with ``roc`` and files ending in ``.py`` with ``python3``.

    Returns:
        Elapsed wall-clock time of the whole command invocation (process
        startup included), in seconds.

    Raises:
        ValueError: If ``benchmark_file`` has an unsupported extension.
        SystemExit: With status 1 if the benchmark exits with a non-zero
            return code; its stderr is printed first.
    """
    # Validate before spawning anything so the error path leaks no process.
    if benchmark_file.endswith(".roc"):
        cmd = ["roc", benchmark_file]
    elif benchmark_file.endswith(".py"):
        cmd = ["python3", benchmark_file]
    else:
        raise ValueError(f"Unknown file type: {benchmark_file}")

    # Feed stdin directly instead of piping from a separate `echo` process:
    # that helper was never waited on (zombie) and its startup cost was
    # included in the measurement.
    stdin_payload = f"{size}\n".encode()

    # perf_counter is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted.
    start_time = time.perf_counter()
    completed = subprocess.run(
        cmd,
        input=stdin_payload,
        capture_output=True,
        check=False,
    )
    duration = time.perf_counter() - start_time

    if completed.returncode != 0:
        print(f"Error running benchmark with size {size}:")
        # Never let undecodable bytes mask the real failure.
        print(completed.stderr.decode(errors="replace"))
        sys.exit(1)

    return duration


def _exponential_sizes(start: int, limit: int) -> list:
    """Return start, 2*start, 4*start, ... for every value not exceeding limit."""
    sizes = []
    size = start
    while size <= limit:
        sizes.append(size)
        size *= 2
    return sizes


def _print_results_table(sizes, roc_durations, py_durations) -> None:
    """Print one row per size with both timings and the Python/Roc time ratio."""
    print("Results:")
    print(f"{'Size':>8} | {'Roc (s)':>10} | {'Python (s)':>10} | {'Speedup':>10}")
    print("-" * 50)
    for size, roc_dur, py_dur in zip(sizes, roc_durations, py_durations):
        speedup = py_dur / roc_dur
        print(f"{size:>8} | {roc_dur:>10.4f} | {py_dur:>10.4f} | {speedup:>9.2f}x")


def _plot_results(sizes_array, roc_times, py_times, output_file: str) -> None:
    """Draw the log-log comparison and the per-element plot, then save the figure."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: Log-log plot comparing Roc vs Python
    ax1.loglog(
        sizes_array,
        roc_times,
        "o-",
        linewidth=2,
        markersize=8,
        label="Roc",
        color="blue",
    )
    ax1.loglog(
        sizes_array,
        py_times,
        "s-",
        linewidth=2,
        markersize=8,
        label="Python",
        color="green",
    )

    # O(n) reference line: Roc's first measurement scaled linearly with size.
    reference_n = roc_times[0] * (sizes_array / sizes_array[0])
    ax1.loglog(
        sizes_array,
        reference_n,
        "--",
        alpha=0.6,
        linewidth=2,
        label="O(n) reference",
        color="gray",
    )

    ax1.set_xlabel("Input Size (n)", fontsize=12)
    ax1.set_ylabel("Execution Time (seconds)", fontsize=12)
    ax1.set_title(
        "Benchmark: fold_rev with Append\nRoc vs Python (Log-Log Scale)",
        fontsize=14,
        fontweight="bold",
    )
    ax1.grid(True, alpha=0.3, which="both")
    ax1.legend(fontsize=10)

    # Plot 2: time per element. Seconds * 1e3 gives milliseconds per element,
    # which is the unit shown on the y-axis label.
    roc_normalized = roc_times / sizes_array * 1e3
    py_normalized = py_times / sizes_array * 1e3

    ax2.semilogx(
        sizes_array,
        roc_normalized,
        "o-",
        linewidth=2,
        markersize=8,
        label="Roc",
        color="blue",
    )
    ax2.semilogx(
        sizes_array,
        py_normalized,
        "s-",
        linewidth=2,
        markersize=8,
        label="Python",
        color="green",
    )
    ax2.set_xlabel("Input Size (n)", fontsize=12)
    ax2.set_ylabel("Time / n (ms)", fontsize=12)
    ax2.set_title(
        "Normalized Time per n\n(Should be ~constant for O(n))",
        fontsize=14,
        fontweight="bold",
    )
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_file, dpi=150, bbox_inches="tight")


def main():
    """Benchmark the Roc and Python programs over doubling sizes and plot the results.

    Expects ``bench_fold_append.roc`` and ``bench_fold_append.py`` in the
    current directory and exits with status 1 if either is missing. Timings
    are printed as a table, plotted to ``benchmark_results.png`` and the
    figure is then shown.
    """
    roc_file = "bench_fold_append.roc"
    py_file = "bench_fold_append.py"

    # Check that both benchmark programs exist before doing any work.
    for required in (roc_file, py_file):
        if not Path(required).exists():
            print(f"Error: {required} not found in current directory")
            sys.exit(1)

    # Doubling sizes: 50, 100, 200, ..., 6400 (largest value <= 10000).
    sizes = _exponential_sizes(50, 10000)

    # Warm-up: run both programs with a trivial size, once per planned
    # measurement; the timings are intentionally discarded.
    print("Warming up ...")
    for _ in sizes:
        run_benchmark(1, roc_file)
        run_benchmark(1, py_file)
    print("=" * 60)

    print("Running benchmarks...")
    print(f"Sizes to test: {sizes}")
    print("=" * 60)

    roc_durations = []
    py_durations = []

    for i, size in enumerate(sizes, 1):
        print(f"[{i}/{len(sizes)}] Testing size {size:>6}...", end=" ", flush=True)

        roc_duration = run_benchmark(size, roc_file)
        roc_durations.append(roc_duration)

        py_duration = run_benchmark(size, py_file)
        py_durations.append(py_duration)

        print(f"✓ Roc: {roc_duration:.4f}s, Python: {py_duration:.4f}s")

    print("=" * 60)
    print("Benchmark complete!")
    print()

    _print_results_table(sizes, roc_durations, py_durations)

    output_file = "benchmark_results.png"
    _plot_results(
        np.array(sizes),
        np.array(roc_durations),
        np.array(py_durations),
        output_file,
    )
    print(f"\nPlot saved to: {output_file}")

    # Show plot
    plt.show()


# Run the benchmark suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
