"""
MOAT v3: Random exploration policy (shared across all agents).
Agents differ ONLY in how they learn their action-sensitivity model B_est.
dcf = ||B_est @ action|| (agent-specific, since B_est differs).
G_ctrl = truth (evaluator-only).
This isolates model quality from action-selection quality.
"""
import numpy as np
from scipy.stats import pearsonr

class MOATEnv:
    """Toy environment with a controllable state and an uncontrollable chaotic state.

    ``x_ctrl`` follows noisy linear dynamics driven by the action and is
    clipped to [-5, 5]. ``x_chaos`` follows a tent map and never depends on
    the action. Observations are a noisy mix of both.
    """

    # True dynamics matrices. They are constants, so build them once instead
    # of on every call to step().
    _A = np.array([[0.90, 0.05], [-0.05, 0.90]])
    _B = np.eye(2) * 0.35

    def __init__(self, seed=0):
        """Seed the environment RNG and draw the initial controllable state."""
        self.rng = np.random.default_rng(seed)
        self.x_ctrl = self.rng.normal(0, 0.3, 2)
        self.x_chaos = np.array([0.3, 0.7])

    def step(self, action, noise=0.05):
        """Advance both sub-states by one tick.

        Args:
            action: length-2 control vector applied to the controllable state.
            noise: standard deviation of the Gaussian process noise on x_ctrl.
        """
        drift = self._A @ self.x_ctrl + self._B @ action
        self.x_ctrl = np.clip(drift + self.rng.normal(0, noise, 2), -5, 5)
        # Keep the tent-map state strictly inside (0, 1): 0 is a fixed point
        # of the map (and 1 maps onto it), so clipping avoids getting stuck.
        x, y = np.clip(self.x_chaos, 0.001, 0.999)
        self.x_chaos = np.array([2*x if x < 0.5 else 2*(1-x),
                                  2*y if y < 0.5 else 2*(1-y)])

    def observe(self, chaos_mix=0.45, obs_noise=0.08):
        """Return a noisy observation mixing controllable and chaotic state.

        Args:
            chaos_mix: weight of the centred chaotic state in the observation.
            obs_noise: standard deviation of the Gaussian observation noise.

        Returns:
            Length-2 array ``x_ctrl + chaos_mix * (x_chaos - 0.5) + noise``.
            With the default arguments this matches the previous fixed
            0.45 / 0.08 behaviour.
        """
        return (self.x_ctrl
                + chaos_mix*(self.x_chaos - 0.5)
                + self.rng.normal(0, obs_noise, 2))

    def true_ctrl(self):
        """Return a copy of the noise-free controllable state (ground truth)."""
        return self.x_ctrl.copy()

# Transition matrix used by the learning agents' one-step predictions
# (A_shared @ obs + B_est @ action) and by the evaluator's pred_do / pred_null.
# NOTE(review): it differs from the true matrix inside MOATEnv.step
# ([[0.90, 0.05], [-0.05, 0.90]]) -- presumably a deliberate model mismatch; confirm.
A_shared = np.array([[0.87, 0.03], [-0.03, 0.87]])

# ─── Agent base: receives shared pre-generated actions ─────

class AgentBase:
    """Common scaffold for all agents.

    Each agent keeps a 2x2 action-sensitivity estimate ``B_est`` and a list
    ``rec`` of per-step records. Subclasses supply the learning rule in
    ``update``.
    """

    name = "Base"

    def __init__(self):
        # Learned action sensitivity; starts with no assumed effect of actions.
        self.B_est = np.zeros(shape=(2, 2))
        # One dict per update() call, filled in by subclasses.
        self.rec = []

    def dcf(self, obs, action):
        """Return ||B_est @ action|| as a Python float (``obs`` is unused)."""
        predicted_effect = self.B_est @ action
        magnitude = np.linalg.norm(predicted_effect)
        return float(magnitude)

    def update(self, obs, action, obs_next, g_ctrl_signal=0.0):
        """Learn from one transition; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

class NoiseTracker(AgentBase):
    """Baseline agent with no action model: B_est is never changed from zero."""

    name = "NoiseTracker"

    def update(self, obs, action, obs_next, g_ctrl_signal=0.0):
        """Record a dcf of 0.0 for this step; all inputs are ignored."""
        self.rec.append({"dcf": 0.0})

class CFWorldModelAgent(AgentBase):
    """Fits B_est with a plain prediction-error gradient step.

    Provides a dcf signal (E2' satisfied), but the learning rule never looks
    at G_ctrl (E3'' not satisfied).
    """

    name = "CFWorldModel"

    def update(self, obs, action, obs_next, g_ctrl_signal=0.0):
        """Take one fixed-rate gradient step on B_est; ``g_ctrl_signal`` is ignored."""
        # dcf is recorded with the model as it was BEFORE this step's update.
        sensitivity = self.dcf(obs, action)
        residual = obs_next - (A_shared @ obs + self.B_est @ action)
        correction = 0.015 * np.outer(residual, action)
        # Entries are kept inside [-3, 3].
        self.B_est = np.clip(self.B_est + correction, -3, 3)
        self.rec.append({"dcf": sensitivity})

class ChaosFakeObserver(AgentBase):
    """Deliberately mis-attributing agent.

    Grows the diagonal of B_est in proportion to the magnitude of the next
    observation, so large (chaos-driven) observations inflate the apparent
    action sensitivity. Intended to show E3 that is unstable under
    policy/env perturbation.
    """

    name = "ChaosFake"
    alpha = 0.40

    def update(self, obs, action, obs_next, g_ctrl_signal=0.0):
        """Inflate diag(B_est) by |obs_next|; action and G_ctrl play no role in learning."""
        # dcf is recorded with the model as it was BEFORE this step's update.
        sensitivity = self.dcf(obs, action)
        # False attribution: observation magnitude is treated as action sensitivity.
        inflation = self.alpha * np.diag(np.abs(obs_next))
        self.B_est = np.clip(self.B_est + 0.015 * inflation, -3, 3)
        self.rec.append({"dcf": sensitivity})

class ObserverAgent(AgentBase):
    """Prediction-error learner whose step size is gated by the G_ctrl signal.

    The learning rate grows with positive ``g_ctrl_signal`` and falls back to
    a small floor otherwise, so controllable directions are reinforced
    selectively.
    """

    name = "Observer"

    def update(self, obs, action, obs_next, g_ctrl_signal=0.0):
        """Take one gradient step on B_est with a G_ctrl-calibrated rate."""
        # dcf is recorded with the model as it was BEFORE this step's update.
        sensitivity = self.dcf(obs, action)
        residual = obs_next - (A_shared @ obs + self.B_est @ action)

        # Rate is 0 for non-positive signals and saturates towards 0.015 for
        # large positive ones.
        positive_part = max(0.0, g_ctrl_signal)
        rate = 0.015 * positive_part / (abs(g_ctrl_signal) + 0.3)
        # Floor keeps a little learning alive so B_est can still shrink.
        if rate < 0.002:
            rate = 0.002

        self.B_est = np.clip(self.B_est + rate * np.outer(residual, action), -3, 3)
        self.rec.append({"dcf": sensitivity})

# ─── Evaluation ───────────────────────────────────────────

def run_episode(AgentCls, env_seed=0, n=400, action_noise_scale=1.0):
    """Run one episode under the shared random policy and score the agent.

    Args:
        AgentCls: agent class to instantiate (no constructor arguments).
        env_seed: seed for the environment; the policy RNG uses ``env_seed + 33``.
        n: number of environment steps.
        action_noise_scale: multiplier on the random action before clipping.

    Returns:
        Pearson correlation between the agent's per-step dcf and the
        evaluator's G_ctrl, or 0.0 when there are fewer than 20 finite
        samples, either series is (near-)constant, or the correlation is
        not finite.
    """
    world = MOATEnv(seed=env_seed)
    learner = AgentCls()
    policy_rng = np.random.default_rng(env_seed + 33)
    samples = []      # (dcf, g_ctrl) pairs with both values finite
    lagged_g = 0.0    # the agent only ever sees the previous step's G_ctrl

    for _ in range(n):
        obs_now = world.observe()
        # Shared random exploration policy: identical for every agent.
        raw_action = action_noise_scale * policy_rng.normal(0, 0.6, 2)
        action = np.clip(raw_action, -1.5, 1.5)

        # Evaluator predictions without and with the action's effect.
        baseline = A_shared @ obs_now
        with_action = baseline + np.eye(2)*0.32 @ action

        world.step(action)
        truth = world.true_ctrl()
        # Positive when accounting for the action brings the prediction
        # closer to the true controllable state.
        gain = float(np.linalg.norm(truth - baseline)
                     - np.linalg.norm(truth - with_action))
        obs_after = world.observe()

        learner.update(obs_now, action, obs_after, g_ctrl_signal=lagged_g)
        lagged_g = gain

        dcf_now = learner.rec[-1]['dcf']
        if not (np.isfinite(dcf_now) and np.isfinite(gain)):
            continue
        samples.append((dcf_now, gain))

    if len(samples) < 20:
        return 0.0
    dcf_arr = np.array([pair[0] for pair in samples])
    g_arr = np.array([pair[1] for pair in samples])
    # Correlation is undefined for a constant series.
    if np.std(dcf_arr) < 1e-9 or np.std(g_arr) < 1e-9:
        return 0.0
    corr, _pvalue = pearsonr(dcf_arr, g_arr)
    if np.isfinite(corr):
        return float(corr)
    return 0.0

def e3_score(cls, n=8):
    """Return (mean, std) of the run_episode score over env seeds 0..n-1."""
    per_seed = []
    for seed in range(n):
        per_seed.append(run_episode(cls, env_seed=seed))
    return float(np.mean(per_seed)), float(np.std(per_seed))

def e4_score(cls, n_seeds=5):
    """Measure score stability across exploration intensities.

    For each action-noise scale the episode score is averaged over
    ``n_seeds`` seeds. Returns (mean, std) of those per-scale averages.
    """
    # Perturbation: vary action noise scale (analogous to policy perturbation).
    scales = (0.5, 0.8, 1.0, 1.3, 1.7, 2.2)
    per_scale = []
    for scale in scales:
        runs = [run_episode(cls, env_seed=seed, action_noise_scale=scale)
                for seed in range(n_seeds)]
        per_scale.append(np.mean(runs))
    return float(np.mean(per_scale)), float(np.std(per_scale))

if __name__ == '__main__':
    # Agent classes under comparison, in reporting order.
    AGENTS = [NoiseTracker, CFWorldModelAgent, ChaosFakeObserver, ObserverAgent]
    rule = "=" * 65
    print(rule)
    print("MOAT v3: Shared exploration, agent-specific B_est learning")
    print(rule)

    # E3: correlation between each agent's dcf and the evaluator's G_ctrl.
    print("\n[ E3: Cov(||B_est @ action||, G_ctrl) ]")
    e3r = {}
    for cls in AGENTS:
        e3r[cls.name] = e3_score(cls)
        mean, std = e3r[cls.name]
        # Text bar, shifted by 0.1 so slightly negative means still render.
        bar = '#' * max(0, int((mean + 0.1) * 35))
        print(f"  {cls.name:20s}: {mean:+.4f} ± {std:.4f}  |{bar}")

    # E4': spread of the E3 score as the exploration scale varies.
    print("\n[ E4': Stability across exploration intensity ]")
    e4r = {}
    for cls in AGENTS:
        e4r[cls.name] = e4_score(cls)
        mean, std = e4r[cls.name]
        if std < 0.12:
            flag = "STABLE  "
        else:
            flag = "UNSTABLE"
        print(f"  {cls.name:20s}: E[E3]={mean:+.4f}  Std={std:.4f}  [{flag}]")

    # An agent passes only with a high enough E3 mean AND a low enough E4 spread.
    print("\n[ Verdict ]")
    for cls in AGENTS:
        e3_pass = e3r[cls.name][0] > 0.10
        e4_pass = e4r[cls.name][1] < 0.12
        e3_mark = '✓' if e3_pass else '✗'
        e4_mark = '✓' if e4_pass else '✗'
        if e3_pass and e4_pass:
            v = "OBSERVER-AGENT"
        else:
            v = "not observer-agent"
        print(f"  {cls.name:20s}  E3{e3_mark}  E4{e4_mark}  {v}")

# ─── E4 variant B: vary chaos intensity ────────────────────

def run_episode_chaos_mix(AgentCls, env_seed=0, n=400, chaos_mix=0.45):
    """Run one episode like ``run_episode`` but with a variable chaos mixing weight.

    Observations are built here as ``x_ctrl + chaos_mix * (x_chaos - 0.5) + noise``
    rather than through ``env.observe()``. The observation noise is drawn from
    the episode RNG (seeded with ``env_seed + 33``), not from the environment's
    own RNG.

    Args:
        AgentCls: agent class to instantiate (no constructor arguments).
        env_seed: seed for the environment and, offset by 33, the episode RNG.
        n: number of environment steps.
        chaos_mix: weight of the centred chaotic state in each observation.

    Returns:
        Pearson correlation between the agent's per-step dcf and G_ctrl, or
        0.0 when there are fewer than 20 finite samples, either series is
        (near-)constant, or the correlation is not finite.
    """
    env    = MOATEnv(seed=env_seed)
    agent  = AgentCls()
    rng    = np.random.default_rng(env_seed + 33)
    dcf_list, g_ctrl_list = [], []
    prev_g = 0.0

    def observe_mixed():
        # Reads the environment state at call time. env.step() rebinds
        # x_ctrl / x_chaos to new arrays, so values captured before the step
        # would describe the OLD state.
        return env.x_ctrl + chaos_mix * (env.x_chaos - 0.5) + rng.normal(0, 0.08, 2)

    for _ in range(n):
        obs = observe_mixed()

        action = np.clip(rng.normal(0, 0.6, 2), -1.5, 1.5)
        pred_do   = A_shared @ obs + np.eye(2)*0.32 @ action
        pred_null = A_shared @ obs
        env.step(action)
        tc = env.true_ctrl()
        g_ctrl = float(np.linalg.norm(tc - pred_null) - np.linalg.norm(tc - pred_do))
        # Fix: observe AFTER the step. Previously the next observation reused
        # the pre-step state, so agents were trained on a re-noised copy of
        # the current observation instead of the real transition.
        obs_nx = observe_mixed()

        # The agent sees the previous step's G_ctrl, as in run_episode.
        agent.update(obs, action, obs_nx, g_ctrl_signal=prev_g)
        prev_g = g_ctrl
        d = agent.rec[-1]['dcf']
        if np.isfinite(d) and np.isfinite(g_ctrl):
            dcf_list.append(d)
            g_ctrl_list.append(g_ctrl)

    if len(dcf_list) < 20:
        return 0.0
    d, g = np.array(dcf_list), np.array(g_ctrl_list)
    # Correlation is undefined for a constant series.
    if np.std(d) < 1e-9 or np.std(g) < 1e-9:
        return 0.0
    r, _ = pearsonr(d, g)
    return float(r) if np.isfinite(r) else 0.0

def e4b_score(cls, n_seeds=5):
    """E4b: score stability as the chaos mixing weight varies.

    For each chaos level the episode score is averaged over ``n_seeds``
    seeds. Returns (mean, std) of those per-level averages.
    """
    chaos_levels = (0.0, 0.2, 0.45, 0.7, 0.9, 1.2)
    per_level = []
    for level in chaos_levels:
        runs = [run_episode_chaos_mix(cls, env_seed=seed, chaos_mix=level)
                for seed in range(n_seeds)]
        per_level.append(np.mean(runs))
    return float(np.mean(per_level)), float(np.std(per_level))

if __name__ == '__main__':
    # Guarded: AGENTS is only defined when the file runs as a script, so the
    # unguarded version raised NameError on import (after already printing
    # the header). The guard also keeps the sweep from running on import.
    print("\n[ E4b: Stability under CHAOS INTENSITY variation ]")
    print("  (Observer should be stable; ChaosFake should degrade when chaos increases)")
    for cls in AGENTS:
        mean, std = e4b_score(cls)
        # NOTE(review): threshold is 0.10 here but 0.12 in the E4' report -- confirm intended.
        flag = "STABLE  " if std < 0.10 else "UNSTABLE"
        print(f"  {cls.name:20s}: E[E3]={mean:+.4f}  Std={std:.4f}  [{flag}]")
