from __future__ import annotations

import math
import random
import sys
from collections import defaultdict
from pathlib import Path

import pandas as pd


# Machine-specific absolute locations: WORK holds the shared analysis module
# imported below, OUT receives every CSV/markdown output of this script.
# NOTE(review): hard-coded Windows paths make the script non-portable; consider
# a CLI flag or environment-variable override.
WORK = Path(r"C:\Users\yauki\Documents\Codex\2026-06-27\codex-codex-remaining-k-chain-64\work")
OUT = Path(r"C:\Users\yauki\Documents\Codex\2026-06-27\codex-codex-remaining-k-chain-64\outputs")
# Put WORK first on the import path so the `base` import resolves there;
# this must run before that import (hence the E402 suppression).
sys.path.insert(0, str(WORK))

import paradoxical_sequence_analysis as base  # noqa: E402


# Remaining-K bins defining the transition under study.
FROM_BIN = "64-95"
TO_BIN = "32-63"
TRANSITION = FROM_BIN + " -> " + TO_BIN

# Occurrence features tallied for START_IN_LAYER occurrences, in report order.
FEATURE_TYPES = [
    "transition_k",
    "pre_k_window_3",
    "pre_k_window_5",
    "local_rolling_k_window_3",
    "local_rolling_k_window_4",
]


def _tk_is_one(feats):
    """Return whether the occurrence's own step has k category "1"."""
    return feats["transition_k"] == "1"


# (label, predicate) pairs; each predicate receives one occurrence's feature dict.
JOINT_CELLS = [
    (
        "transition_k=1 & pre_k_window_3=1,1,1",
        lambda feats: _tk_is_one(feats) and feats["pre_k_window_3"] == "1,1,1",
    ),
    (
        "transition_k=1 & local_rolling_k_window_4=1,1,1,1",
        # local_rolling_* features are sets of window patterns: membership test.
        lambda feats: _tk_is_one(feats) and "1,1,1,1" in feats["local_rolling_k_window_4"],
    ),
    (
        "transition_k=1 & pre_k_window_5=2,1,1,1,1",
        lambda feats: _tk_is_one(feats) and feats["pre_k_window_5"] == "2,1,1,1,1",
    ),
]

# (feature_type, pattern) pairs tracked as counter-examples to the joint cells.
COUNTER_PATTERNS = [
    ("transition_k", "2"),
    ("pre_k_window_3", "1,1,3+"),
    ("pre_k_window_3", "1,1,2"),
    ("local_rolling_k_window_4", "2,1,1,1"),
]

# Combined actual+iid mass below which a row is flagged as low support.
LOW_SUPPORT = 0.002


def local_windows(word: tuple[int, ...], pos: int, width: int) -> set[str]:
    """Return the patterns of every `width`-long window of `word` that covers `pos`.

    Window starts run from ``pos - width + 1`` (clamped at 0) through ``pos``,
    but never beyond the last start at which a full window still fits. When
    the word is shorter than `width`, the single window starting at 0 (the
    whole word) is used. Each window is rendered with
    ``base.pattern(..., cap=True)``; duplicates collapse in the returned set.
    """
    first = max(0, pos - width + 1)
    final_fit = max(0, len(word) - width)
    stop = min(pos, final_fit) + 1
    return {base.pattern(word[s : s + width], cap=True) for s in range(first, stop)}


def occurrence_features(word: tuple[int, ...], pos: int) -> dict[str, object]:
    """Describe the step at index `pos` of `word`.

    Returns a dict with:
    - "transition_k": ``base.kcat`` of the step's own k;
    - "pre_k_window_3" / "pre_k_window_5": ``base.pattern(..., cap=True)`` of
      the up-to-3 / up-to-5 steps ending at `pos` (shorter near the word start);
    - "local_rolling_k_window_3" / "local_rolling_k_window_4": the sets of
      window patterns covering `pos` (see `local_windows`).
    """
    end = pos + 1
    features: dict[str, object] = {}
    features["transition_k"] = base.kcat(word[pos])
    features["pre_k_window_3"] = base.pattern(word[max(0, pos - 2) : end], cap=True)
    features["pre_k_window_5"] = base.pattern(word[max(0, pos - 4) : end], cap=True)
    for key, width in (("local_rolling_k_window_3", 3), ("local_rolling_k_window_4", 4)):
        features[key] = local_windows(word, pos, width)
    return features


def feature_values(features: dict[str, object], feature_type: str) -> list[str]:
    """Return one feature's value(s) as a list of strings.

    Set-valued features are returned sorted; any other value becomes a
    single-element list holding its ``str()``.
    """
    raw = features[feature_type]
    return sorted(raw) if isinstance(raw, set) else [str(raw)]


def add_example(example_keep, label: str, row: dict[str, object], weight: float) -> None:
    """Record `row` under `label`, periodically pruning to the heaviest rows.

    Entries are stored as ``(weight, row)``. Whenever a label's list grows
    past 24 entries it is sorted by weight (descending; ties keep insertion
    order) and truncated to 12, which bounds memory while preserving the
    overall 12 heaviest rows seen so far.
    """
    bucket = example_keep[label]
    bucket.append((weight, row))
    if len(bucket) <= 24:
        return
    bucket.sort(key=lambda entry: entry[0], reverse=True)
    del bucket[12:]


def classify_counter(row: pd.Series) -> str:
    """Return a short qualitative reading for one counter-pattern row.

    `row` must provide ``pass_share_actual``, ``pass_share_iid``,
    ``stay_share_actual``, ``stay_share_iid`` and ``support`` (as produced by
    `summarize_joint`). Deltas are actual share minus iid share.

    Readings, checked in this order:
    - "support artifact candidate": support below LOW_SUPPORT;
    - "undetermined (nan share)": a pass or stay share delta is NaN (e.g. a
      zero reference mass), so no direction can be read;
    - "actual stay-enriched / iid pass-enriched": pass delta < 0 and stay delta >= 0;
    - "iid-enriched pass pattern": any other negative pass delta;
    - "actual stay-enriched pattern": stay delta exceeds the (non-negative) pass delta;
    - "mixed": everything else.
    """
    pass_delta = float(row["pass_share_actual"]) - float(row["pass_share_iid"])
    stay_delta = float(row["stay_share_actual"]) - float(row["stay_share_iid"])
    support = float(row["support"])
    if support < LOW_SUPPORT:
        return "support artifact candidate"
    # NaN compares False against everything and would otherwise silently fall
    # through every branch below to "mixed".
    if math.isnan(pass_delta) or math.isnan(stay_delta):
        return "undetermined (nan share)"
    if pass_delta < 0:
        if stay_delta >= 0:
            return "actual stay-enriched / iid pass-enriched"
        return "iid-enriched pass pattern"
    if stay_delta > pass_delta:
        return "actual stay-enriched pattern"
    return "mixed"


def analyze_word(
    word: tuple[int, ...],
    source: str,
    power: int,
    h: int,
    weight: float,
    route_counts,
    feature_counts,
    joint_counts,
    counter_counts,
    example_keep,
) -> None:
    """Add one word's FROM_BIN occurrences into the shared mass tallies.

    A step of `word` is an occurrence when the remaining K before it falls in
    FROM_BIN and the remaining K after it falls in TO_BIN ("pass") or stays in
    FROM_BIN ("stay"); steps ending anywhere else are ignored. For each
    occurrence, `weight` is added to:

    - ``route_counts[(route, source, group)]`` and ``[("ALL", source, group)]``;
    - for START_IN_LAYER occurrences only: `feature_counts` (every feature
      value), `joint_counts` (each matching JOINT_CELLS label) and
      `counter_counts` (each present COUNTER_PATTERNS pair).

    Selected occurrences are also offered to `example_keep` as detail rows via
    `add_example`. Words for which ``base.state_info`` returns None contribute
    nothing.
    """
    info = base.state_info(word, power, h)
    if info is None:
        return
    state, bridge, window, parity = info
    path = base.bin_path(word)
    total_k = sum(word)
    # Word-level text columns for example rows ("k_prefix_8", "word"). They are
    # identical for every occurrence in this word, so render them at most once,
    # and only if some example row is actually recorded.
    word_text: tuple[str, str] | None = None

    remaining = total_k
    for pos, k in enumerate(word):
        # Advance the running remainder up front so every `continue` below
        # leaves it consistent for the next step.
        before, remaining = remaining, remaining - k
        after = remaining
        _from_idx, from_bin = base.remaining_k_bin(before)
        _to_idx, to_bin = base.remaining_k_bin(after)
        if from_bin != FROM_BIN:
            continue
        if to_bin == TO_BIN:
            group = "pass"
        elif to_bin == FROM_BIN:
            group = "stay"
        else:
            continue

        route = base.entry_route(path, pos, FROM_BIN)
        features = occurrence_features(word, pos)
        route_counts[(route, source, group)] += weight
        route_counts[("ALL", source, group)] += weight
        in_layer = route == "START_IN_LAYER"
        if in_layer:
            for feature_type in FEATURE_TYPES:
                for value in feature_values(features, feature_type):
                    feature_counts[(feature_type, value, source, group)] += weight
            for label, predicate in JOINT_CELLS:
                if predicate(features):
                    joint_counts[(label, source, group)] += weight
            for feature_type, value in COUNTER_PATTERNS:
                if value in feature_values(features, feature_type):
                    counter_counts[(feature_type, value, source, group)] += weight

        # Decide which example labels (if any) this occurrence feeds before
        # paying for the row dict.
        labels = []
        if in_layer and features["transition_k"] == "1" and features["pre_k_window_3"] == "1,1,1":
            labels.append(f"{source}_{group}_tk1_pre111")
        # Not restricted to START_IN_LAYER; the row records entry_route.
        if source == "actual" and group == "pass" and features["transition_k"] == "2":
            labels.append("actual_pass_counter_transition_k_2")
        if not labels:
            continue

        if word_text is None:
            word_text = (base.pattern(word[:8], cap=False), base.pattern(word, cap=False))
        row = {
            "label": "",
            "source": source,
            "group": group,
            "entry_route": route,
            "state": state,
            "bridge": bridge,
            "xk_window": window,
            "parity": parity,
            "power": power,
            "h": h,
            "position": pos,
            "word_length": len(word),
            "total_k": total_k,
            "remaining_K_before": before,
            "remaining_K_after": after,
            "transition_k": features["transition_k"],
            "pre_k_window_3": features["pre_k_window_3"],
            "pre_k_window_5": features["pre_k_window_5"],
            "local_rolling_k_window_3": ";".join(sorted(features["local_rolling_k_window_3"])),
            "local_rolling_k_window_4": ";".join(sorted(features["local_rolling_k_window_4"])),
            "weight": weight,
            "k_prefix_8": word_text[0],
            "word": word_text[1],
        }
        for label in labels:
            ex = dict(row)
            ex["label"] = label
            add_example(example_keep, label, ex, weight)


def rate(pass_mass: float, stay_mass: float) -> float:
    """Return the pass fraction ``pass_mass / (pass_mass + stay_mass)``.

    Returns NaN when the two masses sum to zero.
    """
    total = pass_mass + stay_mass
    if not total:
        return math.nan
    return pass_mass / total


def summarize_route(route_counts) -> pd.DataFrame:
    """Tabulate pass/stay masses and pass rates per entry route for TRANSITION.

    `route_counts` maps ``(route, source, group)`` to mass, with source
    "actual"/"iid" and group "pass"/"stay" (as filled by `analyze_word`, which
    also mirrors every occurrence into the "ALL" route). Rows are always
    emitted for "ALL", "START_IN_LAYER" and "INFLOW_FROM_96-127"; any other
    route present in `route_counts` follows in sorted order, so that no
    route's mass is hidden inside "ALL". Missing keys count as zero, and
    `route_counts` is only read (no zero entries are inserted).
    """
    fixed_routes = ["ALL", "START_IN_LAYER", "INFLOW_FROM_96-127"]
    extra_routes = sorted({key[0] for key in route_counts} - set(fixed_routes))

    def mass(route: str, source: str, group: str) -> float:
        return route_counts.get((route, source, group), 0.0)

    all_actual_pass = mass("ALL", "actual", "pass")
    all_iid_pass = mass("ALL", "iid", "pass")
    rows = []
    for route in fixed_routes + extra_routes:
        ap = mass(route, "actual", "pass")
        ast = mass(route, "actual", "stay")
        ip = mass(route, "iid", "pass")
        ist = mass(route, "iid", "stay")
        actual_rate = rate(ap, ast)
        iid_rate = rate(ip, ist)
        support = ap + ast + ip + ist
        rows.append(
            {
                "transition": TRANSITION,
                "entry_route": route,
                "actual_from_mass": ap + ast,
                "iid_from_mass": ip + ist,
                "mass_delta": (ap + ast) - (ip + ist),
                "actual_pass_mass": ap,
                "iid_pass_mass": ip,
                "actual_stay_mass": ast,
                "iid_stay_mass": ist,
                "actual_pass_rate": actual_rate,
                "iid_pass_rate": iid_rate,
                "conditional_delta": actual_rate - iid_rate,
                "support": support,
                "support_note": "low" if support < LOW_SUPPORT else "ok",
                "actual_share_of_total_pass_mass": ap / all_actual_pass if all_actual_pass else math.nan,
                "iid_share_of_total_pass_mass": ip / all_iid_pass if all_iid_pass else math.nan,
            }
        )
    return pd.DataFrame(rows)


def summarize_features(feature_counts) -> pd.DataFrame:
    """Compare actual vs iid pattern shares among START_IN_LAYER occurrences.

    `feature_counts` maps ``(feature_type, pattern, source, group)`` to mass
    (filled by `analyze_word` for START_IN_LAYER occurrences only). Each
    pattern's mass is expressed as a share of the total mass of its feature
    type within the same (source, group); a share is NaN when that total is
    zero. For set-valued features (local_rolling_*), one occurrence adds its
    weight to every window pattern it has, so those totals count windows
    rather than occurrences.

    Returns one row per (feature_type, pattern) in sorted order. The column set
    is fixed even when `feature_counts` is empty, so downstream column lookups
    keep working. `feature_counts` is only read.
    """
    columns = [
        "transition",
        "entry_route",
        "feature_type",
        "pattern",
        "actual_pass",
        "iid_pass",
        "actual_stay",
        "iid_stay",
        "actual_pass_share",
        "iid_pass_share",
        "pass_share_delta",
        "actual_stay_share",
        "iid_stay_share",
        "stay_share_delta",
        "support",
        "support_note",
    ]
    totals = defaultdict(float)
    for (feature_type, _value, source, group), mass in feature_counts.items():
        totals[(feature_type, source, group)] += mass

    def share(mass: float, feature_type: str, source: str, group: str) -> float:
        denom = totals[(feature_type, source, group)]
        return mass / denom if denom else math.nan

    rows = []
    for feature_type, value in sorted({key[:2] for key in feature_counts}):
        ap = feature_counts.get((feature_type, value, "actual", "pass"), 0.0)
        ast = feature_counts.get((feature_type, value, "actual", "stay"), 0.0)
        ip = feature_counts.get((feature_type, value, "iid", "pass"), 0.0)
        ist = feature_counts.get((feature_type, value, "iid", "stay"), 0.0)
        ap_share = share(ap, feature_type, "actual", "pass")
        ip_share = share(ip, feature_type, "iid", "pass")
        ast_share = share(ast, feature_type, "actual", "stay")
        ist_share = share(ist, feature_type, "iid", "stay")
        support = ap + ast + ip + ist
        rows.append(
            {
                "transition": TRANSITION,
                "entry_route": "START_IN_LAYER",
                "feature_type": feature_type,
                "pattern": value,
                "actual_pass": ap,
                "iid_pass": ip,
                "actual_stay": ast,
                "iid_stay": ist,
                "actual_pass_share": ap_share,
                "iid_pass_share": ip_share,
                "pass_share_delta": ap_share - ip_share,
                "actual_stay_share": ast_share,
                "iid_stay_share": ist_share,
                "stay_share_delta": ast_share - ist_share,
                "support": support,
                "support_note": "low" if support < LOW_SUPPORT else "ok",
            }
        )
    return pd.DataFrame(rows, columns=columns)


def summarize_joint(joint_counts, counter_counts, route_counts) -> pd.DataFrame:
    """Tabulate START_IN_LAYER pass rates and shares for joint cells and counter-patterns.

    `joint_counts` is keyed ``(label, source, group)`` with labels from
    JOINT_CELLS; `counter_counts` is keyed ``(feature_type, value, source,
    group)`` with pairs from COUNTER_PATTERNS (shown as ``"type=value"``).
    `route_counts` supplies the START_IN_LAYER pass/stay totals used as share
    denominators (NaN when a total is zero). Counter rows additionally get a
    ``counter_reading`` from `classify_counter`; joint rows get "". Missing
    keys count as zero and no input mapping is modified.
    """
    # (pattern_kind, display label, counts mapping, key prefix within it)
    containers = [("joint", label, joint_counts, (label,)) for label, _predicate in JOINT_CELLS]
    containers += [
        ("counter", f"{feature_type}={value}", counter_counts, (feature_type, value))
        for feature_type, value in COUNTER_PATTERNS
    ]

    def start_mass(source: str, group: str) -> float:
        return route_counts.get(("START_IN_LAYER", source, group), 0.0)

    start_actual_pass = start_mass("actual", "pass")
    start_iid_pass = start_mass("iid", "pass")
    start_actual_stay = start_mass("actual", "stay")
    start_iid_stay = start_mass("iid", "stay")

    rows = []
    for kind, label, counts, prefix in containers:
        ap = counts.get(prefix + ("actual", "pass"), 0.0)
        ast = counts.get(prefix + ("actual", "stay"), 0.0)
        ip = counts.get(prefix + ("iid", "pass"), 0.0)
        ist = counts.get(prefix + ("iid", "stay"), 0.0)
        actual_rate = rate(ap, ast)
        iid_rate = rate(ip, ist)
        rows.append(
            {
                "transition": TRANSITION,
                "entry_route": "START_IN_LAYER",
                "pattern_kind": kind,
                "pattern": label,
                "actual_pass_rate": actual_rate,
                "iid_pass_rate": iid_rate,
                "conditional_delta": actual_rate - iid_rate,
                "actual_support": ap + ast,
                "iid_support": ip + ist,
                "support": ap + ast + ip + ist,
                "support_note": "low" if ap + ast + ip + ist < LOW_SUPPORT else "ok",
                "actual_pass_mass": ap,
                "iid_pass_mass": ip,
                "actual_stay_mass": ast,
                "iid_stay_mass": ist,
                "pass_share_actual": ap / start_actual_pass if start_actual_pass else math.nan,
                "pass_share_iid": ip / start_iid_pass if start_iid_pass else math.nan,
                "stay_share_actual": ast / start_actual_stay if start_actual_stay else math.nan,
                "stay_share_iid": ist / start_iid_stay if start_iid_stay else math.nan,
            }
        )
    df = pd.DataFrame(rows)
    df["counter_reading"] = ""
    is_counter = df["pattern_kind"] == "counter"
    if is_counter.any():
        df.loc[is_counter, "counter_reading"] = df.loc[is_counter].apply(classify_counter, axis=1)
    return df


def examples_df(example_keep, per_label: int = 12) -> pd.DataFrame:
    """Flatten kept examples into one DataFrame, heaviest first within each label.

    `example_keep` maps label -> list of ``(weight, row_dict)``. Labels are
    emitted in sorted order; for each, at most `per_label` rows with the
    largest weights are kept (ties keep insertion order). `example_keep` is
    left unmodified; an empty mapping yields an empty DataFrame.
    """
    rows = []
    for label in sorted(example_keep):
        # sorted() instead of list.sort(): do not reorder the caller's lists.
        ranked = sorted(example_keep[label], key=lambda item: item[0], reverse=True)
        rows.extend(row for _weight, row in ranked[:per_label])
    return pd.DataFrame(rows)


def fmt(x: float) -> str:
    """Render `x` with 6 significant digits (``g`` format), or "nan" when missing."""
    return "nan" if pd.isna(x) else format(float(x), ".6g")


def build_report(route: pd.DataFrame, features: pd.DataFrame, joint: pd.DataFrame, examples: pd.DataFrame) -> str:
    """Render the markdown deep-dive report from the summary tables.

    Args:
        route: per-entry-route table from `summarize_route`.
        features: START_IN_LAYER feature-share table from `summarize_features`.
        joint: joint-cell and counter-pattern table from `summarize_joint`.
        examples: example-row table from `examples_df`; only its row count is used.

    Returns:
        The report as a single newline-terminated string.

    NOTE(review): the interpretive text (the paragraph after the entry-route
    bullets and the "Short Reading" bullets) is fixed wording and is not checked
    against the computed numbers; re-verify it whenever the inputs change.
    """
    lines = [
        "# Paradoxical 64-95 -> 32-63 Deep Dive",
        "",
        "## Purpose",
        "",
        "Focused descriptive decomposition of the cleanest paradoxical transition, `64-95 -> 32-63`. The goal is to identify which sequence-level components carry the actual conditional excess. This does not claim a mechanism.",
        "",
        "## Entry Route Split",
        "",
    ]
    # One bullet per entry-route row, in the order of the `route` table.
    for r in route.itertuples(index=False):
        lines.append(
            f"- `{r.entry_route}`: actual from `{fmt(r.actual_from_mass)}`, iid from `{fmt(r.iid_from_mass)}`, actual pass `{fmt(r.actual_pass_rate)}`, iid pass `{fmt(r.iid_pass_rate)}`, conditional delta `{fmt(r.conditional_delta)}`, actual pass-share `{fmt(r.actual_share_of_total_pass_mass)}`, iid pass-share `{fmt(r.iid_share_of_total_pass_mass)}`, support `{fmt(r.support)}` ({r.support_note})."
        )
    # Fixed interpretation text; it is not derived from `route`.
    lines.extend(
        [
            "",
            "The readable component is `START_IN_LAYER`: it carries most pass mass and has positive conditional delta. `INFLOW_FROM_96-127` has substantial from-mass deficit but does not carry the conditional excess; its conditional delta is slightly negative.",
            "",
            "## START_IN_LAYER Pass vs Stay",
            "",
        ]
    )
    # Per feature type, list the 8 patterns with the largest |pass_share_delta|.
    for feature_type in FEATURE_TYPES:
        top = features[features["feature_type"] == feature_type].copy()
        top["abs_pass_share_delta"] = top["pass_share_delta"].abs()
        top = top.sort_values("abs_pass_share_delta", ascending=False).head(8)
        lines.append(f"### {feature_type}")
        for r in top.itertuples(index=False):
            lines.append(
                f"- `{r.pattern}`: pass-share delta `{fmt(r.pass_share_delta)}`, stay-share delta `{fmt(r.stay_share_delta)}`, support `{fmt(r.support)}` ({r.support_note})."
            )
        lines.append("")
    lines.extend(["## Joint Pattern Cells", ""])
    # Joint-cell rows of the `joint` table.
    for r in joint[joint["pattern_kind"] == "joint"].itertuples(index=False):
        lines.append(
            f"- `{r.pattern}`: actual pass `{fmt(r.actual_pass_rate)}`, iid pass `{fmt(r.iid_pass_rate)}`, conditional delta `{fmt(r.conditional_delta)}`, pass share actual/iid `{fmt(r.pass_share_actual)}` / `{fmt(r.pass_share_iid)}`, stay share actual/iid `{fmt(r.stay_share_actual)}` / `{fmt(r.stay_share_iid)}`, support `{fmt(r.support)}` ({r.support_note})."
        )
    lines.extend(["", "## Counter-Patterns", ""])
    # Counter-pattern rows, each with its classify_counter reading.
    for r in joint[joint["pattern_kind"] == "counter"].itertuples(index=False):
        lines.append(
            f"- `{r.pattern}`: conditional delta `{fmt(r.conditional_delta)}`, pass share actual/iid `{fmt(r.pass_share_actual)}` / `{fmt(r.pass_share_iid)}`, stay share actual/iid `{fmt(r.stay_share_actual)}` / `{fmt(r.stay_share_iid)}`, support `{fmt(r.support)}` ({r.support_note}); reading: {r.counter_reading}."
        )
    # Closing sections; the "Short Reading" bullets are fixed text.
    lines.extend(
        [
            "",
            "## Representative Sequences",
            "",
            f"`paradoxical_64_95_examples.csv` contains `{len(examples)}` selected rows across requested labels.",
            "",
            "## Short Reading",
            "",
            "- Conditional excess is concentrated in `START_IN_LAYER`, not in upstream inflow from `96-127`.",
            "- Within `START_IN_LAYER`, the strongest readable positive component is `transition_k=1` together with short all-1 local context, especially `pre_k_window_3=1,1,1` and `local_rolling_k_window_4=1,1,1,1`.",
            "- Counter-patterns are mostly the complement of that component: `transition_k=2` and `1,1,2` / `1,1,3+` are depleted on the actual pass side and are better read as iid-enriched pass or actual-stay-associated patterns, not as support artifacts.",
            "- This supports treating `64-95 -> 32-63` as the cleanest paradoxical transition: the from-band is thin in actual, but the conditional excess has a compact sequence-level carrier.",
            "",
            "## Limits",
            "",
            "- Same sampled long-word / final-state-defined universe as the previous remaining_K chain analysis.",
            "- Low support is marked explicitly in the CSVs.",
            "- The analysis describes concentration and correspondence. It is not a mechanism claim.",
        ]
    )
    return "\n".join(lines) + "\n"


def main() -> None:
    """Run the full analysis and write all outputs to OUT.

    Steps: sample iid words and trace sampled actual escape words, feed each
    through `analyze_word`, then write the entry-route, joint-pattern,
    example and START_IN_LAYER feature CSVs plus the markdown report.
    """
    OUT.mkdir(exist_ok=True)
    module = base.load_source()
    # Seeded RNG; sampling results depend on the order in which it is consumed below.
    rng = random.Random(base.SEED)
    tau_cuts: dict[int, tuple[int, int]] = {}
    # Shared accumulators filled in place by analyze_word.
    route_counts = defaultdict(float)
    feature_counts = defaultdict(float)
    joint_counts = defaultdict(float)
    counter_counts = defaultdict(float)
    example_keep = defaultdict(list)

    # iid pass must run first: it fills tau_cuts, which the actual pass reuses.
    print("sampling iid", flush=True)
    for h in base.HS:
        sampled = base.sample_iid(module, h, rng)
        tau_cuts[h] = base.weighted_tau_cuts(sampled)
        for word, weight in sampled:
            # Only words longer than the second tau cut are analyzed.
            if len(word) <= tau_cuts[h][1]:
                continue
            # Each iid word is analyzed once per power.
            for power in base.POWERS:
                analyze_word(word, "iid", power, h, weight, route_counts, feature_counts, joint_counts, counter_counts, example_keep)
        print(f"iid h={h} done", flush=True)

    print("sampling actual", flush=True)
    for power in base.POWERS:
        status = base.load_status(power)
        for h in base.HS:
            _cut1, cut2 = tau_cuts[h]
            lo, hi, total = module.layer_bounds(power, h)
            # Indices in [lo >> 1, hi >> 1] whose status is module.ESCAPE; each is
            # traced below as trace_escape(2 * idx + 1, ...).
            escape_indices = [idx for idx in range(lo >> 1, (hi >> 1) + 1) if status[idx] == module.ESCAPE]
            chosen = base.evenly_spaced(escape_indices, base.ACTUAL_SAMPLE_PER_PH)
            # Each chosen word stands for len(escape_indices) / len(chosen) escapes,
            # normalized by `total`. NOTE(review): raises ZeroDivisionError if
            # layer_bounds ever returns total == 0 with a non-empty sample — confirm.
            sample_weight = len(escape_indices) / (len(chosen) * total) if chosen else 0.0
            for idx in chosen:
                word = module.trace_escape(2 * idx + 1, 1 << power)
                # Same length cut as the iid sample for this h.
                if len(word) <= cut2:
                    continue
                analyze_word(word, "actual", power, h, sample_weight, route_counts, feature_counts, joint_counts, counter_counts, example_keep)
            print(f"actual power={power} h={h} done", flush=True)
        # Drop this power's status table before loading the next one.
        del status

    print("writing outputs", flush=True)
    route = summarize_route(route_counts)
    features = summarize_features(feature_counts)
    joint = summarize_joint(joint_counts, counter_counts, route_counts)
    examples = examples_df(example_keep)
    # utf-8-sig adds a BOM to the CSVs.
    route.to_csv(OUT / "paradoxical_64_95_entry_route.csv", index=False, encoding="utf-8-sig")
    joint.to_csv(OUT / "paradoxical_64_95_joint_patterns.csv", index=False, encoding="utf-8-sig")
    examples.to_csv(OUT / "paradoxical_64_95_examples.csv", index=False, encoding="utf-8-sig")
    features.to_csv(OUT / "paradoxical_64_95_start_in_layer_features.csv", index=False, encoding="utf-8-sig")
    (OUT / "paradoxical_64_95_deep_dive.md").write_text(build_report(route, features, joint, examples), encoding="utf-8")
    print(OUT / "paradoxical_64_95_deep_dive.md", flush=True)


if __name__ == "__main__":
    main()
