from __future__ import annotations

import csv
import math
import sys
from collections import Counter
from pathlib import Path
from statistics import mean, median

import pandas as pd


# Absolute Windows paths are hard-coded; edit them before running elsewhere.
# WORK: directory prepended to sys.path below (presumably it holds the helper
# modules imported right after this block — TODO confirm).
WORK = Path(r"C:\Users\yauki\Documents\Codex\2026-06-27\task-analyze-the-12-trajectories-that\work")
# OUT: directory that main() creates if missing and writes its CSV/Markdown outputs into.
OUT = Path(r"C:\Users\yauki\Documents\Codex\2026-06-27\task-analyze-the-12-trajectories-that\outputs")
# CLASSIFICATION: CSV read by main(); main() uses its "mode" and "n_original" columns.
CLASSIFICATION = Path(
    r"C:\Users\yauki\Documents\Codex\2026-06-27\integer-side-projection-of-64-95\outputs\rozier_nonhit_classification.csv"
)
# Import-time side effect: WORK gets top priority on the module search path.
sys.path.insert(0, str(WORK))

from analyze_entry_64_95_boundary import (  # noqa: E402
    event_word_original_n_strict,
    local_windows,
    odd_core,
    path_for_word,
    pattern,
    remaining_k_bin,
)

import paradoxical_sequence_analysis as base  # noqa: E402


# Bin labels; main() compares them with the "from_bin"/"to_bin" values that
# scan_events() obtains from remaining_k_bin().
FROM_BIN = "64-95"  # events with this from_bin form the analysis window
TO_BIN = "32-63"  # first event with this to_bin (from FROM_BIN) is the "pass"
NA = ""  # placeholder stored in per-trajectory rows when a pattern never occurs
CLASSES = ["A_start", "A_inflow", "Other_start"]  # classes kept by main(); others are only counted as skipped
PATTERNS = [1, 2, 3, 4]  # run lengths of consecutive transition_k == 1 that are tracked


def write_csv(path: Path, rows: list[dict[str, object]]) -> None:
    """Write ``rows`` to ``path`` as a CSV file with a header line.

    Args:
        path: Destination file; it is created or overwritten.
        rows: Records to write. The column order is the key order of the
            first record.

    An empty ``rows`` list yields an empty file written as plain UTF-8.
    Otherwise the file is encoded as ``utf-8-sig`` (UTF-8 with a BOM).
    """
    if rows:
        header = list(rows[0])
        # newline="" lets the csv module control line endings itself.
        with path.open("w", newline="", encoding="utf-8-sig") as handle:
            out = csv.DictWriter(handle, fieldnames=header)
            out.writeheader()
            out.writerows(rows)
        return
    path.write_text("", encoding="utf-8")


def scan_events(n: int) -> tuple[tuple[int, ...], list[dict[str, object]]]:
    """Build one event record per entry of the strict event word of ``n``.

    The word comes from ``event_word_original_n_strict(n)``. While walking it,
    the "remaining K" is the sum of the entries not yet consumed: it starts at
    ``sum(word)`` and drops by the current entry ``k`` at each position.

    Args:
        n: Value passed to ``event_word_original_n_strict``.

    Returns:
        ``(word, events)`` where ``events`` holds, for every position, a dict
        with the remaining K before/after the step, the bins of both values
        (via ``remaining_k_bin``), the entry route reported by
        ``base.entry_route``, and two local-context strings.
    """
    word = event_word_original_n_strict(n)
    path = path_for_word(word)
    remaining = sum(word)
    records: list[dict[str, object]] = []
    for index, step in enumerate(word):
        k_before = remaining
        k_after = k_before - step
        src_bin = remaining_k_bin(k_before)
        dst_bin = remaining_k_bin(k_after)
        record: dict[str, object] = {
            "position": index,
            "remaining_K_before": k_before,
            "remaining_K_after": k_after,
            "from_bin": src_bin,
            "to_bin": dst_bin,
            "transition": f"{src_bin} -> {dst_bin}",
            "transition_k": step,
            "entry_route": base.entry_route(path, index, src_bin),
            # Up to three word entries ending at (and including) this position.
            "pre_k_window_3": pattern(word[max(0, index - 2) : index + 1]),
            "local_window_4": ";".join(local_windows(word, index, 4)),
        }
        records.append(record)
        remaining = k_after
    return word, records


def pass_face_class(event: dict[str, object]) -> str:
    """Label a pass event by its entry route and whether it is an all-1 pass.

    An event counts as an all-1 pass when its ``transition_k`` equals 1 and
    its ``pre_k_window_3`` is exactly ``"1,1,1"``.

    Args:
        event: Event dict providing ``entry_route``, ``transition_k`` and
            ``pre_k_window_3``.

    Returns:
        ``A_start``/``Other_start`` for route ``START_IN_LAYER``,
        ``A_inflow``/``Other_inflow`` for route ``INFLOW_FROM_96-127``,
        otherwise ``Other_route_<route>``.
    """
    route_name = str(event["entry_route"])
    is_all1 = False
    # The window string is only consulted when the step itself is a 1.
    if int(event["transition_k"]) == 1:
        is_all1 = event["pre_k_window_3"] == "1,1,1"

    labels = {
        "START_IN_LAYER": ("A_start", "Other_start"),
        "INFLOW_FROM_96-127": ("A_inflow", "Other_inflow"),
    }
    pair = labels.get(route_name)
    if pair is None:
        return f"Other_route_{route_name}"
    all1_label, other_label = pair
    return all1_label if is_all1 else other_label


def pattern_present_at_event(window: list[dict[str, object]], idx: int, length: int) -> bool:
    """Tell whether a run of ``length`` steps with ``transition_k == 1`` ends at ``idx``.

    Args:
        window: Event dicts, each providing ``transition_k``.
        idx: Index in ``window`` of the last event of the candidate run.
        length: Number of consecutive events that must all have k == 1.

    Returns:
        ``True`` when the events ``idx - length + 1 .. idx`` all have
        ``transition_k`` equal to 1; ``False`` when any of them differs or
        when a run longer than one event would start before index 0.
    """
    start = idx - length + 1
    # A single-event pattern only inspects window[idx]; longer patterns need
    # their whole span to fit inside the window.
    if length != 1 and start < 0:
        return False
    for pos in range(start, idx + 1):
        if int(window[pos]["transition_k"]) != 1:
            return False
    return True


def first_occurrence(window: list[dict[str, object]], length: int) -> int | None:
    """Return the first index at which the pattern of ``length`` is present.

    Presence is decided by ``pattern_present_at_event``; indices are checked
    in ascending order.

    Returns:
        The smallest matching index, or ``None`` when no index matches
        (including when ``window`` is empty).
    """
    hits = (
        pos
        for pos in range(len(window))
        if pattern_present_at_event(window, pos, length)
    )
    return next(hits, None)


def survives_to_pass(window: list[dict[str, object]], length: int) -> bool:
    """Tell whether the pattern of ``length`` is present at the last event.

    Returns ``False`` for an empty ``window``; otherwise the result of
    ``pattern_present_at_event`` evaluated at the final index.
    """
    return bool(window) and pattern_present_at_event(window, len(window) - 1, length)


def longest_k1_run(window: list[dict[str, object]]) -> int:
    """Return the length of the longest streak of events with ``transition_k == 1``.

    Args:
        window: Event dicts, each providing ``transition_k``.

    Returns:
        The maximum number of consecutive events whose ``transition_k``
        converts to the integer 1; ``0`` when there is none.
    """
    longest = 0
    streak = 0
    for item in window:
        # Any step other than 1 resets the current streak.
        streak = streak + 1 if int(item["transition_k"]) == 1 else 0
        if streak > longest:
            longest = streak
    return longest


def safe_median(values: list[float]) -> float:
    """Return the median of ``values`` as a float, or NaN when ``values`` is empty."""
    if not values:
        return math.nan
    return float(median(values))


def safe_mean(values: list[float]) -> float:
    """Return the arithmetic mean of ``values`` as a float, or NaN when ``values`` is empty."""
    if not values:
        return math.nan
    return float(mean(values))


def fmt(value: object, digits: int = 3) -> str:
    """Format ``value`` for the report as a fixed-point number.

    Args:
        value: Anything; it is converted with ``float()`` when possible.
        digits: Number of decimal places in the output.

    Returns:
        ``str(value)`` when the value cannot be converted to a float,
        ``"NA"`` when it converts to NaN, otherwise the fixed-point text.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        # Not numeric: fall back to the plain string form.
        return str(value)
    return "NA" if math.isnan(number) else format(number, f".{digits}f")


def main() -> None:
    """Run the all-1 formation analysis and write its CSV and Markdown outputs.

    Steps:
      1. Read CLASSIFICATION, keep rows whose ``mode`` is
         ``original_n_strict`` and collect the distinct ``n_original`` values.
      2. For each value, scan its events and, when it has events with
         ``from_bin == FROM_BIN``, build the window from the first such event
         through the first one whose ``to_bin`` is TO_BIN (inclusive).
      3. Record per-trajectory and per-event rows for the classes in CLASSES.
      4. Write three CSV files and ``all1_formation_report.md`` into OUT and
         print the report path.

    Raises:
        RuntimeError: if the number of trajectories is not 550, if a
            trajectory has layer events but no pass event, or if the
            G0/G1/class counts differ from the hard-coded expected values.
    """
    # Only the last path component is created; missing parents raise.
    OUT.mkdir(exist_ok=True)
    classified = pd.read_csv(CLASSIFICATION)
    universe = classified[classified["mode"] == "original_n_strict"].copy()
    ids = sorted(int(x) for x in universe["n_original"].unique())
    # Dataset-specific guard: the analysis is pinned to exactly 550 trajectories.
    if len(ids) != 550:
        raise RuntimeError(f"expected 550 trajectories, got {len(ids)}")

    rows: list[dict[str, object]] = []  # one record per kept trajectory
    event_rows: list[dict[str, object]] = []  # one record per window event
    g0 = g1 = 0  # g0: no event with from_bin == FROM_BIN; g1: at least one
    skipped_classes = Counter()  # classes outside CLASSES, kept for the error message

    for n in ids:
        _word, events = scan_events(n)
        layer_events = [e for e in events if e["from_bin"] == FROM_BIN]
        if not layer_events:
            g0 += 1
            continue
        g1 += 1
        first64 = layer_events[0]
        first_pass = next((e for e in layer_events if e["to_bin"] == TO_BIN), None)
        if first_pass is None:
            raise RuntimeError(f"G1 without first pass: {n}")
        cls = pass_face_class(first_pass)
        if cls not in CLASSES:
            # Still counted in g1, but contributes no rows.
            skipped_classes[cls] += 1
            continue
        first64_pos = int(first64["position"])
        pass_pos = int(first_pass["position"])
        # NOTE(review): the window is filtered from layer_events, so any event
        # between entry and pass whose from_bin is not FROM_BIN is left out and
        # runs are counted over the remaining events — confirm this is intended.
        window = [e for e in layer_events if first64_pos <= int(e["position"]) <= pass_pos]
        longest = longest_k1_run(window)
        # Keys of the three dicts below are the run lengths in PATTERNS.
        firsts = {length: first_occurrence(window, length) for length in PATTERNS}
        survives = {length: survives_to_pass(window, length) for length in PATTERNS}
        # "Disappeared" = occurred somewhere in the window but is absent at the pass event.
        disappeared = {
            length: firsts[length] is not None and not survives[length]
            for length in PATTERNS
        }
        for idx, event in enumerate(window):
            # Length of the run of k == 1 events ending at idx (walks backwards).
            run_ending = 0
            j = idx
            while j >= 0 and int(window[j]["transition_k"]) == 1:
                run_ending += 1
                j -= 1
            event_rows.append(
                {
                    "trajectory_id": n,
                    "class": cls,
                    "index_from_entry": idx,
                    "distance_to_pass": len(window) - 1 - idx,
                    "transition_k": event["transition_k"],
                    "pre_k_window_3": event["pre_k_window_3"],
                    "local_window_4": event["local_window_4"],
                    "k1_run_ending_here": run_ending,
                    "has_1_here": int(run_ending >= 1),
                    "has_11_here": int(run_ending >= 2),
                    "has_111_here": int(run_ending >= 3),
                    "has_1111_here": int(run_ending >= 4),
                    "is_pass_event": int(idx == len(window) - 1),
                }
            )
        rows.append(
            {
                "trajectory_id": n,
                "n": n,
                "odd_core": odd_core(n),
                "class": cls,
                "first_64_95_index": first64_pos,
                "first_pass_index": pass_pos,
                "first_pass_wait_events": pass_pos - first64_pos,
                "window_event_count": len(window),
                # first_* are indices into the window (0 = entry event); NA when never seen.
                "first_1": firsts[1] if firsts[1] is not None else NA,
                "first_11": firsts[2] if firsts[2] is not None else NA,
                "first_111": firsts[3] if firsts[3] is not None else NA,
                "first_1111": firsts[4] if firsts[4] is not None else NA,
                "distance_first_111_to_pass": (len(window) - 1 - firsts[3]) if firsts[3] is not None else NA,
                "distance_first_1111_to_pass": (len(window) - 1 - firsts[4]) if firsts[4] is not None else NA,
                "longest_run_before_pass": longest,
                "pattern_1_survives_to_pass": int(survives[1]),
                "pattern_11_survives_to_pass": int(survives[2]),
                "pattern_111_survives_to_pass": int(survives[3]),
                "pattern_1111_survives_to_pass": int(survives[4]),
                "pattern_1_disappears_before_pass": int(disappeared[1]),
                "pattern_11_disappears_before_pass": int(disappeared[2]),
                "pattern_111_disappears_before_pass": int(disappeared[3]),
                "pattern_1111_disappears_before_pass": int(disappeared[4]),
                # The yes/no flag refers to the length-3 pattern only.
                "pattern_survives_to_pass_yes_no": "yes" if survives[3] else "no",
            }
        )

    counts = Counter(row["class"] for row in rows)
    # Hard-coded expected counts; any deviation aborts before files are written.
    # They also guarantee every class in CLASSES is non-empty for the divisions below.
    if g0 != 12 or g1 != 538 or counts["A_start"] != 365 or counts["A_inflow"] != 104 or counts["Other_start"] != 66:
        raise RuntimeError(f"unexpected counts: G0={g0}, G1={g1}, classes={counts}, skipped={skipped_classes}")

    write_csv(OUT / "all1_formation_per_trajectory.csv", rows)
    write_csv(OUT / "all1_formation_event_scan.csv", event_rows)

    # One summary record per class, aggregated from the per-trajectory rows.
    summary_rows: list[dict[str, object]] = []
    for cls in CLASSES:
        sub = [r for r in rows if r["class"] == cls]
        row: dict[str, object] = {"class": cls, "count": len(sub)}
        for field in ["first_1", "first_11", "first_111", "first_1111"]:
            # Trajectories where the pattern never occurred (NA) are left out of
            # the median/mean but still count in the share denominator.
            vals = [float(r[field]) for r in sub if r[field] != NA]
            row[f"median_{field}"] = safe_median(vals)
            row[f"mean_{field}"] = safe_mean(vals)
            row[f"share_ever_{field.replace('first_', '')}"] = len(vals) / len(sub)
        row["median_distance_first_111_to_pass"] = safe_median([float(r["distance_first_111_to_pass"]) for r in sub if r["distance_first_111_to_pass"] != NA])
        row["median_distance_first_1111_to_pass"] = safe_median([float(r["distance_first_1111_to_pass"]) for r in sub if r["distance_first_1111_to_pass"] != NA])
        row["median_longest_run_before_pass"] = safe_median([float(r["longest_run_before_pass"]) for r in sub])
        row["share_ever_reaching_111"] = sum(1 for r in sub if r["first_111"] != NA) / len(sub)
        row["share_ever_reaching_1111"] = sum(1 for r in sub if r["first_1111"] != NA) / len(sub)
        row["share_losing_111_before_pass"] = sum(int(r["pattern_111_disappears_before_pass"]) for r in sub) / len(sub)
        row["share_maintaining_111_until_pass"] = sum(int(r["pattern_111_survives_to_pass"]) for r in sub) / len(sub)
        row["share_losing_1111_before_pass"] = sum(int(r["pattern_1111_disappears_before_pass"]) for r in sub) / len(sub)
        row["share_maintaining_1111_until_pass"] = sum(int(r["pattern_1111_survives_to_pass"]) for r in sub) / len(sub)
        summary_rows.append(row)
    write_csv(OUT / "all1_formation_class_summary.csv", summary_rows)

    # Per-class summary records used by the report text below.
    ast = next(r for r in summary_rows if r["class"] == "A_start")
    ain = next(r for r in summary_rows if r["class"] == "A_inflow")
    oth = next(r for r in summary_rows if r["class"] == "Other_start")

    # Breakdown of Other_start by what happened to the length-3 pattern.
    other = [r for r in rows if r["class"] == "Other_start"]
    other_never111 = sum(1 for r in other if r["first_111"] == NA)
    other_lost111 = sum(1 for r in other if int(r["pattern_111_disappears_before_pass"]))
    other_maintain111 = sum(1 for r in other if int(r["pattern_111_survives_to_pass"]))
    # "Late" = the first length-3 run ends at the pass event or the event just before it.
    other_late111 = sum(
        1
        for r in other
        if r["distance_first_111_to_pass"] != NA and int(r["distance_first_111_to_pass"]) <= 1
    )

    # Markdown report, one list item per output line.
    report = [
        "# Formation of all-1 context before first pass",
        "",
        "Observational analysis only. Dataset and scanner mode: `original_n_strict`. Window: first `64-95` entry through first `64-95 -> 32-63` pass, inclusive.",
        "",
        "## Verified counts",
        "",
        f"- total trajectories: `550`",
        f"- G0 never enters `64-95`: `{g0}`",
        f"- G1 enters `64-95`: `{g1}`",
        f"- A_start: `{counts['A_start']}`",
        f"- A_inflow: `{counts['A_inflow']}`",
        f"- Other_start: `{counts['Other_start']}`",
        "",
        "## Class summary",
        "",
        f"- A_start median first 111: `{fmt(ast['median_first_111'])}`, median first 1111: `{fmt(ast['median_first_1111'])}`, maintain-111-to-pass share `{fmt(ast['share_maintaining_111_until_pass'])}`.",
        f"- A_inflow median first 111: `{fmt(ain['median_first_111'])}`, median first 1111: `{fmt(ain['median_first_1111'])}`, maintain-111-to-pass share `{fmt(ain['share_maintaining_111_until_pass'])}`.",
        f"- Other_start median first 111: `{fmt(oth['median_first_111'])}`, median first 1111: `{fmt(oth['median_first_1111'])}`, maintain-111-to-pass share `{fmt(oth['share_maintaining_111_until_pass'])}`.",
        "",
        "Other_start decomposition:",
        f"- never reaches 111: `{other_never111}/{len(other)}`",
        f"- reaches 111 but loses it before pass: `{other_lost111}/{len(other)}`",
        f"- maintains 111 to pass: `{other_maintain111}/{len(other)}`",
        f"- first 111 appears at distance <= 1 from pass: `{other_late111}/{len(other)}`",
        "",
        "## Main descriptive answers",
        "",
        "1. The all-1 context grows as a run process inside the layer: `1` and `11` appear earlier and more broadly; `111` and `1111` distinguish the all-1 pass faces.",
        f"2. In A_start and A_inflow, 111 is usually formed before the pass rather than only at the pass: median distance from first 111 to pass is `{fmt(ast['median_distance_first_111_to_pass'])}` for A_start and `{fmt(ain['median_distance_first_111_to_pass'])}` for A_inflow.",
        f"3. Other_start mostly fails by not forming 111 in the pre-pass window (`{other_never111}/{len(other)}`), with a smaller group that forms 111 and then loses it (`{other_lost111}/{len(other)}`).",
        f"4. A_start is characterized by stable maintenance of 111 through the pass in this dataset: maintain-111 share `{fmt(ast['share_maintaining_111_until_pass'])}` and maintain-1111 share `{fmt(ast['share_maintaining_1111_until_pass'])}`.",
        "",
        "## Output files",
        "",
        "- `all1_formation_per_trajectory.csv`",
        "- `all1_formation_event_scan.csv`",
        "- `all1_formation_class_summary.csv`",
    ]
    (OUT / "all1_formation_report.md").write_text("\n".join(report) + "\n", encoding="utf-8")
    print(OUT / "all1_formation_report.md")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
