from __future__ import annotations

import importlib.util
import math
import random
from collections import Counter, defaultdict
from pathlib import Path

import pandas as pd


# Helper script loaded by load_source(); this file uses its tilted_k,
# layer_bounds, trace_escape and ESCAPE attributes.
SRC = Path(r"C:\Users\yauki\Documents\design\Collatz\py\2026-06-24_Collatz\python\collatz_escape_word_deficit.py")
# Directory that receives the CSV files and the Markdown report written by main().
OUT = Path(r"C:\Users\yauki\Documents\Codex\2026-06-27\codex-codex-remaining-k-chain-64\outputs")

# Exponents p for the actual-start runs: p is passed as 1 << p to
# module.trace_escape, and the status caches for p hold 2**(p - 1) bytes.
POWERS = [24, 25, 26, 27, 28]
# h values used both for iid sampling and for module.layer_bounds(power, h).
HS = [2, 3, 4, 5, 6]
# Number of iid words drawn per h by sample_iid.
IID_SAMPLES_PER_H = 160_000
# Maximum number of escaping starts traced per (power, h), picked by evenly_spaced.
ACTUAL_SAMPLE_PER_PH = 20_000
# Seed for the single random.Random instance used for iid sampling.
SEED = 20260625
# log2(3); each letter k moves the log-height path by LOG2_3 - k.
LOG2_3 = math.log2(3.0)
# cluster_from_z thresholds: z <= Q_LOW -> "late_growth", z >= Q_HIGH -> "early_growth".
Q_LOW = -1.5
Q_HIGH = -0.25
# NOTE(review): not referenced in this file; focus_rows reports every observed
# state and build_report only singles out "late_growth|deep_32_63|even".
FOCUS_STATES = [
    "late_growth|deep_32_63|even",
    "late_growth|deep_32_63|odd",
    "late_growth|exhaustion_0_31|odd",
]
# (from_bin, to_bin) remaining_K bin pairs analysed; labels must match BIN_LABELS.
TRANSITIONS = [
    ("96-127", "64-95"),
    ("64-95", "32-63"),
    ("32-63", "16-31"),
]
# Labels for remaining_k_bin's bins, in the same order as its lower bounds.
BIN_LABELS = ["0-1", "2-3", "4-7", "8-15", "16-31", "32-63", "64-95", "96-127", "128-191", "192+"]
# Directories searched in order by load_status for odd_only_status_p{p}.bin files.
CACHE_DIRS = [
    Path(r"C:\Users\yauki\Documents\Codex\2026-06-25\new-chat-3\work\status_cache"),
    Path(r"C:\Users\yauki\Documents\Codex\2026-06-25\new-chat-2\work\status_cache"),
    Path(r"C:\Users\yauki\Documents\Codex\2026-06-25\new-chat\work\status_cache"),
]


def load_source():
    """Execute the helper script at SRC and return the resulting module object.

    The module is built under the name ``collatz_escape_word_deficit``; this
    function does not register it in ``sys.modules``.

    Raises:
        RuntimeError: if importlib cannot produce a spec with a loader for SRC.
    """
    module_spec = importlib.util.spec_from_file_location("collatz_escape_word_deficit", SRC)
    loader = None if module_spec is None else module_spec.loader
    if loader is None:
        raise RuntimeError(f"cannot load {SRC}")
    loaded = importlib.util.module_from_spec(module_spec)
    loader.exec_module(loaded)
    return loaded


def load_status(power: int) -> bytearray:
    """Return the cached status table for ``power`` as a mutable byte array.

    The file ``odd_only_status_p{power}.bin`` must contain exactly
    ``2**(power - 1)`` bytes. main() indexes it with ``idx`` for the odd
    start ``2 * idx + 1``. Each directory in CACHE_DIRS is tried in order,
    and the first file with the expected size is returned.

    Raises:
        FileNotFoundError: if no candidate file has the expected size. The
            message lists every path checked and why it was rejected, so a
            truncated or stale cache is distinguishable from a missing one.
    """
    name = f"odd_only_status_p{power}.bin"
    expected = 1 << (power - 1)
    rejected: list[str] = []
    for directory in CACHE_DIRS:
        path = directory / name
        if not path.exists():
            rejected.append(f"{path} (absent)")
            continue
        size = path.stat().st_size
        if size == expected:
            return bytearray(path.read_bytes())
        # A wrong-sized file is almost certainly truncated or from another run;
        # remember it so the final error explains why it was skipped.
        rejected.append(f"{path} (size {size}, expected {expected})")
    raise FileNotFoundError(f"missing cached status for p={power}; checked: " + "; ".join(rejected))


def evenly_spaced(items: list[int], limit: int) -> list[int]:
    """Return at most ``limit`` items chosen at evenly spaced positions.

    - ``limit <= 0`` yields an empty list.
    - If ``items`` already fits within ``limit``, a copy of it is returned, so
      callers never alias the input list.
    - ``limit == 1`` picks the middle item.
    - Otherwise, the first and last items are always included and the rest
      come from rounded, evenly spaced indices in between.
    """
    if limit <= 0:
        return []
    if len(items) <= limit:
        return list(items)
    if limit == 1:
        return [items[len(items) // 2]]
    span = len(items) - 1
    return [items[round(i * span / (limit - 1))] for i in range(limit)]


def sample_iid(module, h: int, rng: random.Random) -> list[tuple[tuple[int, ...], float]]:
    """Draw IID_SAMPLES_PER_H random words for level ``h`` with their weights.

    For each draw:
    - ``y`` is uniform on [0.5, 1) and the target distance is ``h - log2(y)``.
    - Letters ``k = module.tilted_k(rng)`` are appended, moving the path by
      ``LOG2_3 - k`` each step, until the path position exceeds the target.
    - The returned weight is ``2**-h * y * 2**-overshoot / IID_SAMPLES_PER_H``,
      where ``overshoot`` is how far the final position passed the target.

    NOTE(review): the weight looks like an importance-sampling correction for
    the tilted letter law; confirm against the helper module.
    """
    scale = 2.0 ** (-h)
    samples: list[tuple[tuple[int, ...], float]] = []
    for _draw in range(IID_SAMPLES_PER_H):
        y = 0.5 + 0.5 * rng.random()
        target = h - math.log2(y)
        walk = 0.0
        letters: list[int] = []
        while walk <= target:
            letter = module.tilted_k(rng)
            letters.append(letter)
            walk += LOG2_3 - letter
        overshoot = walk - target
        samples.append((tuple(letters), scale * y * (2.0 ** (-overshoot)) / IID_SAMPLES_PER_H))
    return samples


def weighted_tau_cuts(words: list[tuple[tuple[int, ...], float]]) -> tuple[int, int]:
    """Return the word lengths at the weighted 1/3 and 2/3 quantiles.

    Word lengths are visited in increasing order. A cut is recorded at the
    first length whose cumulative weight reaches each target; one length can
    satisfy both targets. If fewer than two cuts are found, the longest
    length fills the gap.

    Raises:
        ValueError: if ``words`` is empty (from ``max`` on an empty mapping).
    """
    mass_by_length: defaultdict[int, float] = defaultdict(float)
    for word, weight in words:
        mass_by_length[len(word)] += weight
    total_mass = sum(mass_by_length.values())
    thresholds = [total_mass / 3, 2 * total_mass / 3]
    cut_points: list[int] = []
    running = 0.0
    next_target = 0
    for length in sorted(mass_by_length):
        running += mass_by_length[length]
        while next_target < len(thresholds) and running >= thresholds[next_target]:
            cut_points.append(length)
            next_target += 1
    if len(cut_points) < 2:
        longest = max(mass_by_length)
        cut_points.extend([longest] * (2 - len(cut_points)))
    return cut_points[0], cut_points[1]


def path_xs(word: tuple[int, ...]) -> list[float]:
    """Return the log-height path of ``word``, starting at 0.0.

    Element ``i`` is the sum of ``LOG2_3 - k`` over the first ``i`` letters,
    so the result has ``len(word) + 1`` entries.
    """
    heights = [0.0]
    for k in word:
        heights.append(heights[-1] + (LOG2_3 - k))
    return heights


def interp(values: list[float], u: float) -> float:
    """Linearly interpolate ``values`` at fraction ``u`` of its index range.

    ``u = 0`` gives ``values[0]`` and ``u = 1`` gives ``values[-1]``. Other
    values of ``u`` blend the two neighbouring samples around
    ``u * (len(values) - 1)``.
    """
    last = len(values) - 1
    scaled = u * last
    left = int(math.floor(scaled))
    right = left + 1 if left < last else last
    toward_right = scaled - left
    return values[left] * (1.0 - toward_right) + values[right] * toward_right


def z25_feature(word: tuple[int, ...]) -> float:
    """Return how far the path at 25% of its length sits from the straight chord.

    The path starts at 0, so the chord value at ``u = 0.25`` is a quarter of
    the final height. The result is the interpolated path height there minus
    that chord value.
    """
    heights = path_xs(word)
    return interp(heights, 0.25) - 0.25 * heights[-1]


def cluster_from_z(z: float) -> str:
    """Bucket a z25_feature value into a growth-timing label.

    ``z <= Q_LOW`` gives "late_growth". Otherwise ``z >= Q_HIGH`` gives
    "early_growth". Anything else (including NaN) gives "balanced".
    """
    if z <= Q_LOW:
        label = "late_growth"
    elif z >= Q_HIGH:
        label = "early_growth"
    else:
        label = "balanced"
    return label


def xk_window(x_k: int) -> str | None:
    if 0 <= x_k < 32:
        return "exhaustion_0_31"
    if 32 <= x_k < 64:
        return "deep_32_63"
    if 64 <= x_k < 96:
        return "tail_64_95"
    return None


def state_info(word: tuple[int, ...], power: int, h: int) -> tuple[str, str, str, str] | None:
    """Classify a word into ``(state_key, bridge, window, parity)``.

    Steps:
    - ``x_k = sum(word) - (power - h)`` selects the window via xk_window.
      Words outside every window return None.
    - ``bridge`` comes from cluster_from_z(z25_feature(word)).
    - ``parity`` is "even" or "odd" according to ``power``.

    ``state_key`` joins the three labels with "|".
    """
    x_k = sum(word) - (power - h)
    window = xk_window(x_k)
    if window is None:
        return None
    parity = "odd" if power % 2 else "even"
    bridge = cluster_from_z(z25_feature(word))
    state_key = f"{bridge}|{window}|{parity}"
    return state_key, bridge, window, parity


def remaining_k_bin(value: int) -> tuple[int, str]:
    """Map a remaining_K value to ``(bin_index, BIN_LABELS[bin_index])``.

    Bins are half-open ``[lower, next_lower)`` with lower edges
    0, 2, 4, 8, 16, 32, 64, 96, 128 and 192; the last bin is unbounded above.
    Negative values return ``(0, "<0")``.

    NOTE(review): "<0" reuses index 0, the same index as the "0-1" bin.
    Callers in this file only use the label, but any caller keyed on the
    index would conflate the two.
    """
    lower_edges = (0, 2, 4, 8, 16, 32, 64, 96, 128, 192)
    upper_edges = lower_edges[1:] + (math.inf,)
    for index, (low, high) in enumerate(zip(lower_edges, upper_edges)):
        if low <= value < high:
            return index, BIN_LABELS[index]
    return 0, "<0"


def kcat(k: int) -> str:
    """Render ``k`` as itself when ``k <= 2``, otherwise as the capped label "3+"."""
    within_cap = k <= 2
    return str(k) if within_cap else "3+"


def pattern(values: tuple[int, ...] | list[int], cap: bool = False) -> str:
    """Join ``values`` into a comma-separated key.

    With ``cap=True`` each value is rendered through kcat; otherwise str is
    used. An empty sequence returns the sentinel "START".
    """
    if not values:
        return "START"
    render = kcat if cap else str
    return ",".join(render(v) for v in values)


def bin_path(word: tuple[int, ...]) -> list[str]:
    """Return the remaining_K bin label in effect before each letter of ``word``.

    Entry ``i`` is the bin of ``sum(word[i:])``, so the list has one label
    per letter.
    """
    remaining = sum(word)
    labels = []
    for k in word:
        labels.append(remaining_k_bin(remaining)[1])
        remaining -= k
    return labels


def entry_route(path: list[str], pos: int, from_bin: str) -> str:
    """Describe how the word first reached ``from_bin``.

    Finds the first index of ``from_bin`` in ``path``, falling back to
    ``pos`` when it is absent. Index 0 means the word starts in that layer.
    Otherwise the label names the bin immediately before it.
    """
    try:
        first = path.index(from_bin)
    except ValueError:
        first = pos
    if first == 0:
        return "START_IN_LAYER"
    return f"INFLOW_FROM_{path[first - 1]}"


def add_feature(feature_counts, transition: str, source: str, group: str, feature_type: str, value: str, weight: float) -> None:
    """Add ``weight`` under the key ``(transition, feature_type, value, source, group)``."""
    key = (transition, feature_type, value, source, group)
    feature_counts[key] = feature_counts[key] + weight


def add_occurrence_features(feature_counts, transition: str, source: str, group: str, word: tuple[int, ...], pos: int, weight: float) -> None:
    """Accumulate k-pattern features for one transition occurring at ``word[pos]``.

    Each of these adds ``weight`` exactly once:

    - ``k_prefix_3`` / ``k_prefix_5``: capped pattern of the first 3 / 5 letters.
    - ``pre_k_window_3`` / ``pre_k_window_5``: capped pattern of the up-to-3 /
      up-to-5 letters ending at ``pos``.
    - ``transition_k``: capped label of ``word[pos]``.
    - ``pre_raw_k_window_3``: uncapped pattern of the up-to-3 letters ending
      at ``pos``.

    ``local_rolling_k_window_{3,4}`` adds ``weight`` once per distinct capped
    pattern among the full-width windows ``word[start:start + width]`` that
    contain ``pos``. Windows never extend past either end of the word, so a
    word shorter than ``width`` contributes no rolling window.
    """

    def add(feature_type: str, value: str) -> None:
        add_feature(feature_counts, transition, source, group, feature_type, value, weight)

    add("k_prefix_3", pattern(word[:3], cap=True))
    add("k_prefix_5", pattern(word[:5], cap=True))
    add("pre_k_window_3", pattern(word[max(0, pos - 2) : pos + 1], cap=True))
    add("pre_k_window_5", pattern(word[max(0, pos - 4) : pos + 1], cap=True))
    add("transition_k", kcat(word[pos]))
    add("pre_raw_k_window_3", pattern(word[max(0, pos - 2) : pos + 1], cap=False))
    for width in (3, 4):
        # A window [start, start + width) contains pos iff
        # pos - width + 1 <= start <= pos; it is full width iff
        # start <= len(word) - width. The previous bounds were one too wide,
        # admitting the window starting at pos + 1 (which misses pos) and a
        # truncated window at the end of the word.
        first_start = max(0, pos - width + 1)
        last_start = min(pos, len(word) - width)
        seen = {pattern(word[start : start + width], cap=True) for start in range(first_start, last_start + 1)}
        for value in seen:
            add(f"local_rolling_k_window_{width}", value)


def add_record(records, transition: str, source: str, group: str, state: str, route: str, weight: float) -> None:
    """Add ``weight`` under the key ``(transition, source, group, state, route)``."""
    key = (transition, source, group, state, route)
    records[key] = records[key] + weight


def analyze_word(
    word: tuple[int, ...],
    source: str,
    power: int,
    h: int,
    weight: float,
    feature_counts,
    focus_counts,
    route_counts,
    example_rows,
    example_keep,
) -> None:
    """Record every target-transition step of one word into the accumulators.

    Words for which state_info returns None are ignored. For each letter
    whose starting remaining_K bin equals a TRANSITIONS from-bin, the step is
    classified as follows:

    - "pass" if the bin after the letter is the paired to-bin.
    - "stay" if the bin is unchanged.
    - Skipped for any other move.

    Each counted step adds ``weight`` to focus_counts (keyed by the word's
    state), to route_counts (state "ALL") and to the k-pattern features. It
    also appends an example row to ``example_keep[(transition, source,
    group)]``, which is cut back to its 12 heaviest rows whenever it grows
    past 24.

    ``example_rows`` is accepted for call compatibility but is not used.
    """
    info = state_info(word, power, h)
    if info is None:
        return
    state, bridge, window, parity = info
    path = bin_path(word)
    total_k = sum(word)
    remaining_before = total_k
    for pos, k in enumerate(word):
        remaining_after = remaining_before - k
        from_bin = remaining_k_bin(remaining_before)[1]
        to_bin = remaining_k_bin(remaining_after)[1]
        for target_from, target_to in TRANSITIONS:
            if from_bin != target_from:
                continue
            if to_bin == target_to:
                group = "pass"
            elif to_bin == target_from:
                group = "stay"
            else:
                continue
            transition = f"{target_from} -> {target_to}"
            route = entry_route(path, pos, target_from)
            add_record(focus_counts, transition, source, group, state, route, weight)
            add_record(route_counts, transition, source, group, "ALL", route, weight)
            add_occurrence_features(feature_counts, transition, source, group, word, pos, weight)
            example = {
                "transition": transition,
                "source": source,
                "group": group,
                "entry_route": route,
                "state": state,
                "bridge": bridge,
                "xk_window": window,
                "parity": parity,
                "power": power,
                "h": h,
                "position": pos,
                "word_length": len(word),
                "total_k": total_k,
                "remaining_K_before": remaining_before,
                "remaining_K_after": remaining_after,
                "weight": weight,
                "k_prefix_8": pattern(word[:8], cap=False),
                "k_prefix_8_capped": pattern(word[:8], cap=True),
                "pre_k_window_8": pattern(word[max(0, pos - 7) : pos + 1], cap=False),
                "word": pattern(word, cap=False),
            }
            kept = example_keep[(transition, source, group)]
            kept.append((weight, example))
            # Trim in batches so the sort runs only occasionally.
            if len(kept) > 24:
                kept.sort(key=lambda entry: entry[0], reverse=True)
                del kept[12:]
        remaining_before = remaining_after


def feature_rows(feature_counts) -> pd.DataFrame:
    """Compare actual vs iid pattern shares for every observed feature value.

    ``feature_counts`` maps ``(transition, feature_type, pattern, source,
    group)`` to accumulated weight, where source is "actual" or "iid" and
    group is "pass" or "stay". For each ``(transition, feature_type, pattern)``
    the row reports:

    - the raw masses;
    - each mass's share of its ``(transition, feature_type, source, group)``
      total (NaN when that total is zero);
    - the actual-minus-iid share deltas.

    The input is only read: lookups use ``.get`` so a defaultdict is not
    grown with zero entries. An empty input returns an empty DataFrame that
    still has every column, instead of failing in ``sort_values``.
    """
    columns = [
        "transition",
        "feature_type",
        "pattern",
        "actual_pass",
        "iid_pass",
        "actual_stay",
        "iid_stay",
        "actual_pass_share",
        "iid_pass_share",
        "pass_share_delta",
        "actual_stay_share",
        "iid_stay_share",
        "stay_share_delta",
        "pass_minus_stay_delta",
        "support",
        "support_note",
    ]

    def share(part: float, whole: float) -> float:
        return part / whole if whole else math.nan

    keys = sorted({(t, ft, value) for t, ft, value, _source, _group in feature_counts})
    totals: defaultdict[tuple[str, str, str, str], float] = defaultdict(float)
    for (transition, feature_type, _value, source, group), mass in feature_counts.items():
        totals[(transition, feature_type, source, group)] += mass

    rows = []
    for transition, feature_type, value in keys:
        ap = feature_counts.get((transition, feature_type, value, "actual", "pass"), 0.0)
        ip = feature_counts.get((transition, feature_type, value, "iid", "pass"), 0.0)
        ast = feature_counts.get((transition, feature_type, value, "actual", "stay"), 0.0)
        ist = feature_counts.get((transition, feature_type, value, "iid", "stay"), 0.0)
        ap_share = share(ap, totals.get((transition, feature_type, "actual", "pass"), 0.0))
        ip_share = share(ip, totals.get((transition, feature_type, "iid", "pass"), 0.0))
        ast_share = share(ast, totals.get((transition, feature_type, "actual", "stay"), 0.0))
        ist_share = share(ist, totals.get((transition, feature_type, "iid", "stay"), 0.0))
        pass_delta = ap_share - ip_share
        stay_delta = ast_share - ist_share
        support = ap + ip + ast + ist
        rows.append(
            {
                "transition": transition,
                "feature_type": feature_type,
                "pattern": value,
                "actual_pass": ap,
                "iid_pass": ip,
                "actual_stay": ast,
                "iid_stay": ist,
                "actual_pass_share": ap_share,
                "iid_pass_share": ip_share,
                "pass_share_delta": pass_delta,
                "actual_stay_share": ast_share,
                "iid_stay_share": ist_share,
                "stay_share_delta": stay_delta,
                "pass_minus_stay_delta": pass_delta - stay_delta,
                "support": support,
                "support_note": "low" if support < 0.002 else "ok",
            }
        )
    if not rows:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns).sort_values(
        ["transition", "feature_type", "pass_share_delta"], ascending=[True, True, False]
    )


def focus_rows(focus_counts) -> pd.DataFrame:
    """Summarise from/pass/stay masses per transition, final state and entry route.

    ``focus_counts`` maps ``(transition, source, group, state, entry_route)``
    to accumulated weight. For every transition in TRANSITIONS, every
    observed state, and the pseudo-route "ALL" plus each observed route, the
    row compares actual and iid masses and their pass/stay rates. Combinations
    with zero total support are omitted.

    The input is only read (lookups use ``.get``, so a defaultdict is not
    grown). An empty result returns an empty DataFrame that still has every
    column, so downstream column filters keep working.
    """
    columns = [
        "transition",
        "focus_state",
        "entry_route",
        "actual_from_mass",
        "iid_from_mass",
        "mass_delta",
        "actual_pass_mass",
        "iid_pass_mass",
        "actual_stay_mass",
        "iid_stay_mass",
        "actual_pass_rate",
        "iid_pass_rate",
        "conditional_delta",
        "actual_stay_rate",
        "iid_stay_rate",
        "support",
        "support_note",
    ]
    states = sorted({state for _t, _s, _g, state, _route in focus_counts})
    transitions = [f"{a} -> {b}" for a, b in TRANSITIONS]
    routes = sorted({route for _t, _s, _g, _state, route in focus_counts})

    def mass(transition: str, source: str, group: str, state: str, route_filter: list[str]) -> float:
        return sum(focus_counts.get((transition, source, group, state, r), 0.0) for r in route_filter)

    def rate(part: float, whole: float) -> float:
        return part / whole if whole else math.nan

    rows = []
    for transition in transitions:
        for state in states:
            for route in ["ALL", *routes]:
                route_filter = routes if route == "ALL" else [route]
                ap = mass(transition, "actual", "pass", state, route_filter)
                ip = mass(transition, "iid", "pass", state, route_filter)
                ast = mass(transition, "actual", "stay", state, route_filter)
                ist = mass(transition, "iid", "stay", state, route_filter)
                support = ap + ip + ast + ist
                if support == 0:
                    continue
                actual_from = ap + ast
                iid_from = ip + ist
                actual_pass_rate = rate(ap, actual_from)
                iid_pass_rate = rate(ip, iid_from)
                rows.append(
                    {
                        "transition": transition,
                        "focus_state": state,
                        "entry_route": route,
                        "actual_from_mass": actual_from,
                        "iid_from_mass": iid_from,
                        "mass_delta": actual_from - iid_from,
                        "actual_pass_mass": ap,
                        "iid_pass_mass": ip,
                        "actual_stay_mass": ast,
                        "iid_stay_mass": ist,
                        "actual_pass_rate": actual_pass_rate,
                        "iid_pass_rate": iid_pass_rate,
                        "conditional_delta": actual_pass_rate - iid_pass_rate,
                        "actual_stay_rate": rate(ast, actual_from),
                        "iid_stay_rate": rate(ist, iid_from),
                        "support": support,
                        "support_note": "low" if support < 0.002 else "ok",
                    }
                )
    if not rows:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns).sort_values(["transition", "focus_state", "entry_route"])


def route_rows(route_counts) -> pd.DataFrame:
    """Summarise from/pass/stay masses per transition and entry route.

    ``route_counts`` maps ``(transition, source, group, "ALL", entry_route)``
    to accumulated weight. For every transition in TRANSITIONS, the
    pseudo-route "ALL" and each route observed for that transition yield one
    row comparing actual and iid masses and pass rates. Rows with zero total
    support are omitted.

    The input is only read (lookups use ``.get``, so a defaultdict is not
    grown). An empty result returns an empty DataFrame that still has every
    column, so downstream column filters keep working.
    """
    columns = [
        "transition",
        "entry_route",
        "actual_from_mass",
        "iid_from_mass",
        "mass_delta",
        "actual_pass_mass",
        "iid_pass_mass",
        "actual_stay_mass",
        "iid_stay_mass",
        "actual_pass_rate",
        "iid_pass_rate",
        "conditional_delta",
        "support",
        "support_note",
    ]

    def mass(transition: str, source: str, group: str, route_filter: list[str]) -> float:
        return sum(route_counts.get((transition, source, group, "ALL", r), 0.0) for r in route_filter)

    def rate(part: float, whole: float) -> float:
        return part / whole if whole else math.nan

    rows = []
    for transition in [f"{a} -> {b}" for a, b in TRANSITIONS]:
        routes = sorted({route for t, _s, _g, _state, route in route_counts if t == transition})
        for route in ["ALL", *routes]:
            route_filter = routes if route == "ALL" else [route]
            ap = mass(transition, "actual", "pass", route_filter)
            ip = mass(transition, "iid", "pass", route_filter)
            ast = mass(transition, "actual", "stay", route_filter)
            ist = mass(transition, "iid", "stay", route_filter)
            support = ap + ip + ast + ist
            if support == 0:
                continue
            actual_from = ap + ast
            iid_from = ip + ist
            actual_pass_rate = rate(ap, actual_from)
            iid_pass_rate = rate(ip, iid_from)
            rows.append(
                {
                    "transition": transition,
                    "entry_route": route,
                    "actual_from_mass": actual_from,
                    "iid_from_mass": iid_from,
                    "mass_delta": actual_from - iid_from,
                    "actual_pass_mass": ap,
                    "iid_pass_mass": ip,
                    "actual_stay_mass": ast,
                    "iid_stay_mass": ist,
                    "actual_pass_rate": actual_pass_rate,
                    "iid_pass_rate": iid_pass_rate,
                    "conditional_delta": actual_pass_rate - iid_pass_rate,
                    "support": support,
                    "support_note": "low" if support < 0.002 else "ok",
                }
            )
    if not rows:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns).sort_values(["transition", "entry_route"])


def examples_df(example_keep) -> pd.DataFrame:
    """Flatten the kept examples into a DataFrame, 12 heaviest per key.

    Each bucket in ``example_keep`` holds ``(weight, row_dict)`` pairs. Every
    bucket is sorted in place by descending weight, and the first 12 row
    dicts are emitted in that order.
    """
    collected = []
    for bucket in example_keep.values():
        bucket.sort(key=lambda pair: pair[0], reverse=True)
        collected += [record for _w, record in bucket[:12]]
    return pd.DataFrame(collected)


def fmt(x: float) -> str:
    """Format a number with 6 significant digits, or "nan" for missing values."""
    return "nan" if pd.isna(x) else format(x, ".6g")


def _inverse_sign_summary(all_route: pd.DataFrame) -> str:
    """Describe, from the ALL-route rows, where the inverse-sign pattern appears.

    "Inverse sign" means a from-band mass deficit (``mass_delta < 0``)
    together with a conditional pass-rate excess (``conditional_delta > 0``).
    The text is derived from the data, so the report never asserts a pattern
    the current sample does not show.
    """
    if all_route.empty:
        return "No target-transition occurrences were observed in this sample, so the inverse-sign structure cannot be assessed."
    holds = all_route[(all_route["mass_delta"] < 0) & (all_route["conditional_delta"] > 0)]
    listed = ", ".join(f"`{t}`" for t in holds["transition"]) or "none"
    sentences = [
        "The inverse-sign structure (from-band mass deficit with a positive conditional delta) appears in "
        f"{len(holds)} of {len(all_route)} target transitions in this sample: {listed}."
    ]
    excess = all_route["conditional_delta"][all_route["conditional_delta"] > 0]
    if not excess.empty:
        top = all_route.loc[excess.idxmax()]
        sentences.append(
            f"The strongest conditional excess is at `{top['transition']}` (conditional delta `{fmt(top['conditional_delta'])}`)."
        )
    deficit = all_route["mass_delta"][all_route["mass_delta"] < 0]
    if not deficit.empty:
        low = all_route.loc[deficit.idxmin()]
        sentences.append(
            f"The largest from-band mass deficit is at `{low['transition']}` (mass delta `{fmt(low['mass_delta'])}`)."
        )
    return " ".join(sentences)


def build_report(focus: pd.DataFrame, route: pd.DataFrame, features: pd.DataFrame) -> str:
    """Render the Markdown report from the focus, route and feature tables.

    The "Overall Results" interpretation is computed from the ALL-route rows
    instead of being hard-coded, so it stays truthful when the sample
    changes. The other sections list:

    - per-route pass rates;
    - the top 8 pass-share-delta patterns per transition;
    - the requested focus state;
    - fixed notes on limits and next steps.
    """
    lines = [
        "# Paradoxical Sequence Analysis",
        "",
        "## Purpose",
        "",
        "This is an observational sequence-level decomposition of the remaining_K inverse-sign pattern: mass deficit in a band, but conditional excess for the downstream transition. It does not claim a mechanism.",
        "",
        "Target transitions:",
        "",
        "- `96-127 -> 64-95`",
        "- `64-95 -> 32-63`",
        "- `32-63 -> 16-31`",
        "",
        "## Overall Results",
        "",
    ]
    all_route = route[route["entry_route"] == "ALL"]
    for r in all_route.itertuples(index=False):
        lines.append(
            f"- `{r.transition}`: from-mass delta `{fmt(r.mass_delta)}`, actual pass rate `{fmt(r.actual_pass_rate)}`, iid pass rate `{fmt(r.iid_pass_rate)}`, conditional delta `{fmt(r.conditional_delta)}`, support `{fmt(r.support)}`."
        )
    lines.extend(
        [
            "",
            _inverse_sign_summary(all_route),
            "",
            "## Transition Notes",
            "",
        ]
    )
    for transition in [f"{a} -> {b}" for a, b in TRANSITIONS]:
        lines.append(f"### {transition}")
        sub_route = route[route["transition"] == transition].sort_values("entry_route")
        for r in sub_route.itertuples(index=False):
            if r.entry_route == "ALL":
                continue
            lines.append(
                f"- route `{r.entry_route}`: actual from `{fmt(r.actual_from_mass)}`, iid from `{fmt(r.iid_from_mass)}`, actual pass rate `{fmt(r.actual_pass_rate)}`, iid pass rate `{fmt(r.iid_pass_rate)}`, conditional delta `{fmt(r.conditional_delta)}` ({r.support_note} support)."
            )
        top = (
            features[(features["transition"] == transition) & (features["feature_type"].isin(["pre_k_window_3", "pre_k_window_5", "transition_k", "k_prefix_3"]))]
            .sort_values("pass_share_delta", ascending=False)
            .head(8)
        )
        lines.append("")
        lines.append("Top pass-side pattern concentrations:")
        for r in top.itertuples(index=False):
            lines.append(
                f"- `{r.feature_type}` `{r.pattern}`: pass share delta `{fmt(r.pass_share_delta)}`, support `{fmt(r.support)}` ({r.support_note})."
            )
        lines.append("")
    lines.extend(
        [
            "## Focus State",
            "",
            "The requested focus state is `late_growth|deep_32_63|even`. Rows below use all entry routes combined.",
            "",
        ]
    )
    req = focus[(focus["focus_state"] == "late_growth|deep_32_63|even") & (focus["entry_route"] == "ALL")]
    for r in req.itertuples(index=False):
        lines.append(
            f"- `{r.transition}`: actual from `{fmt(r.actual_from_mass)}`, iid from `{fmt(r.iid_from_mass)}`, mass delta `{fmt(r.mass_delta)}`, actual pass rate `{fmt(r.actual_pass_rate)}`, iid pass rate `{fmt(r.iid_pass_rate)}`, conditional delta `{fmt(r.conditional_delta)}`."
        )
    lines.extend(
        [
            "",
            "## Observational Limits",
            "",
            "- This uses the same sampled long-word/final-state-defined universe as the remaining_K chain check, not an exhaustive all-word enumeration.",
            "- Feature rows are weighted occurrence summaries. Rolling windows can count multiple windows per occurrence, so they should be read as concentration diagnostics rather than exclusive decompositions.",
            "- Low-support rows are marked in the CSV files. They are useful for generating hypotheses, but should not be used as stable claims.",
            "- The analysis identifies sequence groups that carry the inverse-sign observation; it does not prove that those groups cause the discrepancy.",
            "",
            "## Next Candidates",
            "",
            "1. Re-run only the strongest feature families with a larger iid sample to test stability.",
            "2. Split `START_IN_LAYER` vs `INFLOW_FROM_*` inside `late_growth|deep_32_63|even` for the `64-95 -> 32-63` and `32-63 -> 16-31` pair.",
            "3. Add prefix-cylinder labels to the representative examples if a paper-facing sequence table is needed.",
        ]
    )
    return "\n".join(lines) + "\n"


def main() -> None:
    """Run the iid-vs-actual sequence analysis and write every output into OUT.

    1. For each h in HS:
       - draw iid words with sample_iid;
       - derive weighted tau cut points;
       - analyse every word longer than the upper cut under each power in
         POWERS.
    2. For each power:
       - load the cached status table;
       - pick up to ACTUAL_SAMPLE_PER_PH escaping odd starts per h with
         evenly_spaced;
       - trace them and analyse those longer than the same upper cut.
    3. Aggregate into DataFrames and write four CSVs plus a Markdown report.
    """
    # parents=True: OUT sits under a dated directory that may not exist yet.
    OUT.mkdir(parents=True, exist_ok=True)
    module = load_source()
    rng = random.Random(SEED)
    tau_cuts: dict[int, tuple[int, int]] = {}
    feature_counts = defaultdict(float)
    focus_counts = defaultdict(float)
    route_counts = defaultdict(float)
    example_keep = defaultdict(list)

    print("sampling iid", flush=True)
    for h in HS:
        sampled = sample_iid(module, h, rng)
        tau_cuts[h] = weighted_tau_cuts(sampled)
        for word, weight in sampled:
            # Only words longer than the upper (2/3) cut are analysed.
            if len(word) <= tau_cuts[h][1]:
                continue
            for power in POWERS:
                analyze_word(word, "iid", power, h, weight, feature_counts, focus_counts, route_counts, None, example_keep)
        print(f"iid h={h} done", flush=True)

    print("sampling actual", flush=True)
    for power in POWERS:
        status = load_status(power)
        for h in HS:
            _cut1, cut2 = tau_cuts[h]
            lo, hi, total = module.layer_bounds(power, h)
            # status[idx] describes the odd start 2 * idx + 1.
            # NOTE(review): this range assumes [lo, hi] is inclusive; if hi can
            # be even, idx = hi >> 1 maps to hi + 1 — confirm with layer_bounds.
            escape_indices = [idx for idx in range(lo >> 1, (hi >> 1) + 1) if status[idx] == module.ESCAPE]
            chosen = evenly_spaced(escape_indices, ACTUAL_SAMPLE_PER_PH)
            # Each traced start stands in for len(escape_indices) / len(chosen)
            # escaping starts; dividing by total presumably normalises to a
            # density comparable with the iid weights — TODO confirm.
            sample_weight = len(escape_indices) / (len(chosen) * total) if chosen else 0.0
            for idx in chosen:
                word = module.trace_escape(2 * idx + 1, 1 << power)
                if len(word) <= cut2:
                    continue
                analyze_word(word, "actual", power, h, sample_weight, feature_counts, focus_counts, route_counts, None, example_keep)
            print(f"actual power={power} h={h} done", flush=True)
        # Release the 2**(power - 1)-byte table before loading the next power.
        del status

    print("writing outputs", flush=True)
    features = feature_rows(feature_counts)
    focus = focus_rows(focus_counts)
    route = route_rows(route_counts)
    examples = examples_df(example_keep)

    features.to_csv(OUT / "paradoxical_sequence_k_window_delta.csv", index=False, encoding="utf-8-sig")
    focus.to_csv(OUT / "paradoxical_sequence_by_focus_state.csv", index=False, encoding="utf-8-sig")
    route.to_csv(OUT / "paradoxical_sequence_entry_route.csv", index=False, encoding="utf-8-sig")
    examples.to_csv(OUT / "paradoxical_sequence_transition_examples.csv", index=False, encoding="utf-8-sig")
    (OUT / "paradoxical_sequence_report.md").write_text(build_report(focus, route, features), encoding="utf-8")
    print(OUT / "paradoxical_sequence_report.md", flush=True)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
