"""
MOAT v5b: E3 redefined as Corr(agent.Ghat, true_G_ctrl)
Key fixes:
- ChaosFake records obs-magnitude as its Ghat (wrong calibration signal)
- ObserverAgent records do-null-diff as Ghat (correct signal)
- E3 = Corr(agent.Ghat, true G_ctrl) → tests calibration quality directly
- Warmup: ObserverAgent learns at a fixed lr during an initial warmup period (WARMUP = 50) to avoid a cold start; the other agents have no warmup phase
"""
import zlib
from collections import deque

import numpy as np
from scipy.stats import pearsonr

class MOATEnv:
    """Two-dimensional toy environment: a controllable linear state plus a
    chaotic nuisance signal that leaks into the observations.

    * ``x_ctrl`` evolves as ``A_true @ x + B_true @ action + noise`` and is
      clipped to [-5, 5].
    * ``x_chaos`` follows an independent tent map in each coordinate.
    * Observations are ``x_ctrl + mix * (x_chaos - 0.5) + noise``.
    """

    def __init__(self, seed=0, chaos_mix=0.45):
        self.rng = np.random.default_rng(seed)
        # The initial control state is the first draw from the seeded generator.
        self.x_ctrl = self.rng.normal(0, 0.3, 2)
        self.x_chaos = np.array([0.35, 0.65])
        self.mix = chaos_mix
        self.B_true = np.eye(2) * 0.35
        self.A_true = np.array([[0.90, 0.05], [-0.05, 0.90]])

    @staticmethod
    def _tent(v):
        """One iteration of the tent map on [0, 1]."""
        return 2 * v if v < 0.5 else 2 * (1 - v)

    def step(self, action, noise=0.05):
        """Advance both subsystems by one step; ``noise`` is the process-noise std."""
        linear = self.A_true @ self.x_ctrl + self.B_true @ action
        self.x_ctrl = np.clip(linear + self.rng.normal(0, noise, 2), -5, 5)
        # Clamp away from 0 and 1 (0 is a fixed point, 1 maps to 0) before iterating.
        bounded = np.clip(self.x_chaos, 0.001, 0.999)
        self.x_chaos = np.array([self._tent(v) for v in bounded])

    def observe(self):
        """Return a noisy observation contaminated by the chaotic signal."""
        leak = self.mix * (self.x_chaos - 0.5)
        return self.x_ctrl + leak + self.rng.normal(0, 0.08, 2)

    def true_ctrl(self):
        """Return a copy of the true control state (for the evaluator)."""
        return self.x_ctrl.copy()

# Dynamics matrix the agents (and the evaluator's reference model) assume.
# It differs from MOATEnv.A_true = [[0.90, 0.05], [-0.05, 0.90]], so the
# agents' one-step model is mis-specified. Used by do_null_diff,
# update_B_step and run_episode.
A_EST = np.array([[0.87, 0.03], [-0.03, 0.87]])
# Length of ObserverAgent's initial constant-lr warmup, in update steps.
# Only ObserverAgent reads this; the other agents have no warmup phase.
WARMUP = 50

def shared_policy(B_est, obs, rng):
    """Linear feedback policy with Gaussian exploration, shared by all agents.

    Returns ``clip((-0.30 * B_est) @ obs + N(0, 0.08^2), -1.5, 1.5)`` and
    draws exactly one length-2 normal sample from ``rng``.
    """
    gain = -0.30 * B_est
    exploration = rng.normal(0, 0.08, 2)
    raw = gain @ obs + exploration
    return np.clip(raw, -1.5, 1.5)

def do_null_diff(B_est, obs, action, obs_next):
    """Self-supervised counterfactual gain estimate (the do-null difference).

    Compares how well the one-step model explains ``obs_next`` without the
    action ("null") and with it ("do"):

        ||obs_next - A_EST obs||^2 - ||obs_next - (A_EST obs + B_est action)||^2

    Positive when including the action lowers the squared prediction error.
    Only observations are used, never the true state.
    """
    baseline = A_EST @ obs
    with_action = baseline + B_est @ action
    err_null = np.linalg.norm(obs_next - baseline) ** 2
    err_do = np.linalg.norm(obs_next - with_action) ** 2
    return err_null - err_do

def update_B_step(B_est, obs, action, obs_next, lr):
    """Return B_est after one gradient step on the squared one-step error.

    The residual ``obs_next - (A_EST obs + B_est action)`` times ``action^T``
    is the negative gradient of half the squared error w.r.t. B. Entries of
    the result are clipped to [-2, 2]; ``B_est`` itself is not modified.
    """
    predicted = A_EST @ obs + B_est @ action
    residual = obs_next - predicted
    step = lr * np.outer(residual, action)
    return np.clip(B_est + step, -2, 2)

# ─── Agents (differ in Ghat SOURCE and lr rule) ───────────

class BaseAgent:
    """Shared scaffolding for the ablation agents.

    Every agent acts through ``shared_policy`` on its own ``B_est``.
    Subclasses differ only in ``_step_update``, i.e. which Ghat they record
    and how they choose the learning rate.
    """
    name = "Base"
    base_lr = 0.015

    def __init__(self):
        self.B_est = np.zeros((2, 2))  # learned input-gain estimate
        self.rec = []                  # one dict per update; run_episode reads rec[-1]['Ghat']
        # Seed from a stable CRC32 of the agent name. Python's built-in hash()
        # of a str is salted per process (PYTHONHASHSEED), so seeding from it
        # made policy noise, and every reported score, differ between runs.
        self._rng = np.random.default_rng(zlib.crc32(self.name.encode("utf-8")))
        self._step = 0                 # number of update() calls so far

    def act(self, obs):
        """Return an action for ``obs`` from the shared (noisy) policy."""
        return shared_policy(self.B_est, obs, self._rng)

    def update(self, obs, action, obs_next):
        """Process one transition. ``_step`` is incremented *before* the hook runs."""
        self._step += 1
        self._step_update(obs, action, obs_next)

    def _step_update(self, obs, action, obs_next):
        """Subclass hook: update ``B_est`` and append a record to ``self.rec``."""
        raise NotImplementedError

class NoiseTracker(BaseAgent):
    """Trivial baseline: never changes B_est and always records Ghat = 0."""
    name = "NoiseTracker"

    def _step_update(self, obs, action, obs_next):
        # Inputs are ignored. A constant Ghat has zero variance, so
        # run_episode scores this agent's E3 as 0.0.
        self.rec.append({'Ghat': 0.0})

class CFWorldModelAgent(BaseAgent):
    """Computes the do-null-difference Ghat but always learns at ``base_lr``.

    Ghat is recorded only so E3 can test its calibration; it never
    influences the learning rate.
    """
    name = "CFWorldModel"

    def _step_update(self, obs, action, obs_next):
        # Ghat is evaluated with the pre-update B_est.
        g_hat = do_null_diff(self.B_est, obs, action, obs_next)
        new_B = update_B_step(self.B_est, obs, action, obs_next, lr=self.base_lr)
        self.B_est = new_B
        self.rec.append({'Ghat': g_hat})

class ChaosFakeObserver(BaseAgent):
    """Uses the squared magnitude of the next observation as its Ghat.

    This is deliberately the WRONG calibration signal (observation magnitude
    is not the control gain). The expectation is that
    E3 = Corr(Ghat_fake, G_ctrl) comes out low or unstable.
    """
    name = "ChaosFake"

    def _step_update(self, obs, action, obs_next):
        fake = float(np.linalg.norm(obs_next) ** 2)
        # Saturating multiplier in [0.3, 1.0) of base_lr, driven by the fake signal.
        gain = 0.3 + 0.7 * fake / (fake + 2.0)
        self.B_est = update_B_step(self.B_est, obs, action, obs_next,
                                   lr=self.base_lr * gain)
        self.rec.append({'Ghat': fake})  # the FAKE Ghat is what E3 evaluates

class Observer_ShuffledGhat(BaseAgent):
    """Ablation: correct do-null Ghat, but the lr uses a temporally shuffled one.

    The lr gate is driven by a Ghat sampled at random from the last 60
    values, which breaks the link between signal and update. The recorded
    Ghat is still the current, correct value. There is no warmup phase.
    """
    name = "Obs_Shuffled"

    def __init__(self):
        super().__init__()
        self._buf = deque(maxlen=60)            # most recent Ghat values
        self._rng2 = np.random.default_rng(42)  # separate stream for the shuffle

    def _step_update(self, obs, action, obs_next):
        current = do_null_diff(self.B_est, obs, action, obs_next)
        self._buf.append(current)
        # Until 10 values are buffered, the current Ghat is used unshuffled.
        driver = (current if len(self._buf) < 10
                  else self._rng2.choice(list(self._buf)))
        gate = max(0.05, 1.0 / (1 + np.exp(-3 * driver)))
        self.B_est = update_B_step(self.B_est, obs, action, obs_next,
                                   lr=self.base_lr * gate)
        self.rec.append({'Ghat': current})  # correct Ghat recorded for evaluation

class ObserverAgent(BaseAgent):
    """Calibrated agent whose learning rate is gated by its own do-null Ghat.

    Uses no ground truth. For the first WARMUP updates the lr is the constant
    ``base_lr`` (same as CFWorldModel). Afterwards it is
    ``base_lr * max(0.05, sigmoid(3 * Ghat))``, so updates are damped when
    taking the action did not reduce the one-step prediction error.
    """
    name = "Observer"

    def _step_update(self, obs, action, obs_next):
        Ghat = do_null_diff(self.B_est, obs, action, obs_next)
        # BaseAgent.update increments _step before calling this hook, so the
        # first call sees _step == 1. ``<=`` gives exactly WARMUP constant-lr
        # updates (``<`` would give only WARMUP - 1).
        if self._step <= WARMUP:
            lr = self.base_lr  # warmup: constant lr same as CFWorldModel
        else:
            lr = self.base_lr * max(0.05, 1.0/(1 + np.exp(-3 * Ghat)))
        self.B_est = update_B_step(self.B_est, obs, action, obs_next, lr=lr)
        self.rec.append(dict(Ghat=Ghat))

# ─── Evaluation ───────────────────────────────────────────

def run_episode(AgentCls, env_seed=0, n=400, chaos_mix=0.45, B_ref=None):
    """Run one episode and return E3 = Pearson r(agent's Ghat, G_ctrl).

    G_ctrl is computed by the evaluator from the environment's true control
    state, which is never passed to the agent. It uses the reference model
    ``A_EST`` / ``B_ref`` applied to the observation the agent acted on.

    Args:
        AgentCls: agent class, instantiated with no arguments.
        env_seed: seed for ``MOATEnv``.
        n: number of environment steps.
        chaos_mix: weight of the chaotic component in observations.
        B_ref: 2x2 reference input gain for G_ctrl; defaults to ``0.32 * I``.

    Returns:
        The correlation as a float. Returns 0.0 in any of these cases:
        fewer than 20 finite (Ghat, G_ctrl) pairs, either series (near-)constant,
        or a non-finite r.
    """
    # Loop-invariant reference gain: build it once instead of every step.
    B_ref = np.eye(2) * 0.32 if B_ref is None else np.asarray(B_ref, dtype=float)
    env = MOATEnv(seed=env_seed, chaos_mix=chaos_mix)
    agent = AgentCls()
    Ghat_list, Gctrl_list = [], []

    for _ in range(n):
        obs = env.observe()
        action = agent.act(obs)
        pred_null = A_EST @ obs
        pred_do = pred_null + B_ref @ action
        env.step(action)
        tc = env.true_ctrl()  # evaluator only
        # NOTE(review): G_ctrl uses plain norms, while do_null_diff (the agents'
        # Ghat) uses squared norms. Pearson r is sensitive to that nonlinearity,
        # so confirm which form is intended.
        g_ctrl = float(np.linalg.norm(tc - pred_null) - np.linalg.norm(tc - pred_do))
        obs_nx = env.observe()
        agent.update(obs, action, obs_nx)  # never passes tc

        Gh = agent.rec[-1]['Ghat']
        if np.isfinite(Gh) and np.isfinite(g_ctrl):
            Ghat_list.append(Gh)
            Gctrl_list.append(g_ctrl)

    if len(Ghat_list) < 20:
        return 0.0
    G, gc = np.array(Ghat_list), np.array(Gctrl_list)
    # pearsonr is undefined for constant input (e.g. NoiseTracker's Ghat = 0).
    if np.std(G) < 1e-9 or np.std(gc) < 1e-9:
        return 0.0
    r, _ = pearsonr(G, gc)
    return float(r) if np.isfinite(r) else 0.0

def e3_score(cls, n=8):
    """E3 for ``cls`` over environment seeds 0..n-1; returns (mean, std)."""
    scores = [run_episode(cls, env_seed=seed) for seed in range(n)]
    mean, spread = np.mean(scores), np.std(scores)
    return float(mean), float(spread)

def e4_score(cls, n_seeds=5, chaos_levels=(0.0, 0.2, 0.45, 0.7, 0.9, 1.2)):
    """E4: stability of E3 across observation chaos levels.

    For each ``chaos_mix`` value in ``chaos_levels``, E3 is averaged over
    environment seeds 0..n_seeds-1. The function returns (mean, std) of those
    per-level averages; a small std means E3 is insensitive to chaos
    contamination.

    Raises:
        ValueError: if ``n_seeds`` < 1 or ``chaos_levels`` is empty, because
            the averages would be NaN.
    """
    if n_seeds < 1 or len(chaos_levels) == 0:
        raise ValueError("e4_score needs n_seeds >= 1 and at least one chaos level")
    per_level = [
        np.mean([run_episode(cls, env_seed=s, chaos_mix=m) for s in range(n_seeds)])
        for m in chaos_levels
    ]
    return float(np.mean(per_level)), float(np.std(per_level))

def e5_proxy(cls, n_seeds=5, n=350, chaos_mix=0.45):
    """E5: alignment of the learned B_est with the environment's true B.

    For each seed, an agent learns for ``n`` steps. Then the cosine similarity
    between each column of ``B_est`` and the matching column of ``env.B_true``
    is collected. Columns where either vector has (near-)zero norm are
    skipped. Returns the mean similarity, or 0.0 if every column was skipped
    (e.g. an agent that never updates B_est).
    """
    cos_sims = []
    for s in range(n_seeds):
        env = MOATEnv(seed=s, chaos_mix=chaos_mix)
        agent = cls()
        for _ in range(n):
            obs = env.observe()
            action = agent.act(obs)
            env.step(action)
            agent.update(obs, action, env.observe())
        # Compare against the simulator's own B_true rather than a hard-coded
        # copy of its value, so this metric cannot drift from the environment.
        for be, bt in zip(agent.B_est.T, env.B_true.T):
            nb, nt = np.linalg.norm(be), np.linalg.norm(bt)
            if nb > 1e-6 and nt > 1e-6:
                cos_sims.append(float(np.dot(be, bt) / (nb * nt)))
    return float(np.mean(cos_sims)) if cos_sims else 0.0

def main():
    """Run the E3/E4/E5 ablation for every agent and print a verdict table."""
    agents = [NoiseTracker, CFWorldModelAgent, ChaosFakeObserver,
              Observer_ShuffledGhat, ObserverAgent]
    # Pass criteria, defined once and shared by the per-metric printouts and
    # the verdict table so they cannot disagree.
    e3_min, e4_max_std, e5_min = 0.20, 0.08, 0.40

    print("=" * 65)
    print("MOAT v5b  —  E3 = Corr(agent.Ghat, true G_ctrl)")
    print("(ChaosFake uses obs_magnitude as Ghat; Observer uses do-null diff)")
    print("=" * 65)

    print("\n[ E3: Corr(agent's internal Ghat, true G_ctrl) ]")
    e3r = {}
    for cls in agents:
        m, sd = e3_score(cls)
        e3r[cls.name] = (m, sd)
        bar = '#' * max(0, int((m + 0.5) * 20))
        print(f"  {cls.name:20s}: {m:+.4f} ± {sd:.4f}  |{bar}")

    print("\n[ E4: Stability under chaos-mix variation ]")
    e4r = {}
    for cls in agents:
        m, sd = e4_score(cls)
        e4r[cls.name] = (m, sd)
        flag = "STABLE  " if sd < e4_max_std else "UNSTABLE"
        print(f"  {cls.name:20s}: E[E3]={m:+.4f} Std={sd:.4f} [{flag}]")

    print("\n[ E5: B_est alignment with true B ]")
    e5r = {}
    for cls in agents:
        cs = e5_proxy(cls)
        e5r[cls.name] = cs
        print(f"  {cls.name:20s}: cos_sim = {cs:+.4f}")

    print("\n[ ABLATION VERDICT ]")
    print(f"  {'Agent':22s} E3>0.2  E4 stable  E5>0.4  Verdict")
    for cls in agents:
        e3p = e3r[cls.name][0] > e3_min
        e4p = e4r[cls.name][1] < e4_max_std
        e5p = e5r[cls.name] > e5_min
        v = "OBSERVER-AGENT ✓" if (e3p and e4p and e5p) else "not observer-agent"
        print(f"  {cls.name:22s} {'✓' if e3p else '✗'}       {'✓' if e4p else '✗'}          "
              f"{'✓' if e5p else '✗'}       {v}")

    print("\n[ EXPECTED PATTERN ]")
    print("  NoiseTracker    Ghat=0 always → E3=0  (trivial fail)")
    print("  CFWorldModel    do-null Ghat but chaos contamination → E3 moderate")
    print("  ChaosFake       obs_magnitude as Ghat → E3 low / E4 unstable")
    print("  Obs_Shuffled    temporal mismatch → E3 lower than Observer")
    print("  Observer        calibrated do-null Ghat → E3 highest, E4 stable")


if __name__ == '__main__':
    main()
