"""
MOAT v5c: Causal Attribution Separation Test

Core hypothesis (separation theorem candidate):
  Observer-agent uses do-null comparison → adapts ONLY to action-attributable changes
  Generic adaptive systems use total prediction error → adapt to ALL changes (ctrl + chaos)

Environment phases:
  t=  0-100: B_true=I*0.35, chaos_mix=0.3  (baseline)
  t=100-200: B_true ROTATED 90°, chaos_mix=0.3  (controllable change)
  t=200-300: B_true same (rotated), chaos_mix=0.9  (uncontrollable change only)
  t=300-400: B_true=I*0.35, chaos_mix=0.9  (controllable change again)

Expected:
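  Observer:      fast adaptation at t=100/300 (hypothesis: do-null G_hat goes negative → high lr)
                 stable at t=200 (chaos doesn't fool do-null comparison)
                 NOTE: ObserverAgent as implemented sets lr ∝ sigmoid(4·G_hat), which
                 LOWERS lr when G_hat < 0 — the hypothesis and the code disagree here.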
  KalmanLike:    adapts to all changes equally (no do-null distinction)
  MetaAdaptive:  error-magnitude adaptive lr adapts to chaos too (no causal filter)
"""

import zlib
from collections import deque

import numpy as np
from scipy.stats import pearsonr

# ─── Non-stationary environment ──────────────────────────

def rot90():
    """Return the 2x2 float matrix of a counter-clockwise 90° rotation."""
    cos_q, sin_q = 0.0, 1.0  # cos(90°), sin(90°)
    return np.array([[cos_q, -sin_q], [sin_q, cos_q]])

class MOATEnvV5c:
    """Non-stationary 2-D control environment with a chaotic, action-independent distractor.

    Dynamics per step():
      x_ctrl  <- clip(A_true @ x_ctrl + B(t) @ action + N(0, noise), -5, 5)
      x_chaos <- tent map of clip(x_chaos, 0.001, 0.999)
    Observation: x_ctrl + mix(t) * (x_chaos - 0.5) + N(0, 0.08).

    B(t) and mix(t) follow ``B_phases`` (start, end, B, mix); once t passes the
    last phase's end, the last phase's (B, mix) is kept.

    NOTE(review): with the fixed start [0.35, 0.65], 2*0.35 == 2*(1 - 0.65) in
    float64, so both chaos coordinates are equal from the first step on (the
    distractor acts only along (1, 1)), and the chaos trajectory is the same for
    every seed. The doubling tent map is also known to collapse toward 0 in
    binary floating point; the clip to [0.001, 0.999] presumably re-seeds it.
    Confirm all of this is intended.
    """

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.x_ctrl = self.rng.normal(0, 0.3, 2)
        self.x_chaos = np.array([0.35, 0.65])
        self._t = 0
        # True dynamics — for the evaluator only; agents never read these.
        quarter_turn = np.array([[0.0, -1.0], [1.0, 0.0]])
        schedule = [
            (0, 100, False, 0.3),    # baseline
            (100, 200, True, 0.3),   # B_true rotated
            (200, 300, True, 0.9),   # chaos increases
            (300, 400, False, 0.9),  # B_true back
        ]
        self.B_phases = [
            (start, end,
             (quarter_turn @ np.eye(2) * 0.35) if rotated else np.eye(2) * 0.35,
             mix)
            for start, end, rotated, mix in schedule
        ]
        self.A_true = np.array([[0.90, 0.05], [-0.05, 0.90]])

    def _get_phase(self):
        """Return (B, mix) for the current step, or the last phase's pair past the schedule."""
        active = ((B, mix) for start, end, B, mix in self.B_phases
                  if start <= self._t < end)
        return next(active, self.B_phases[-1][2:])

    def step(self, action, noise=0.05):
        """Advance one step using the CURRENT phase's B, then increment the clock."""
        B, _ = self._get_phase()
        drift = self.A_true @ self.x_ctrl + B @ action
        self.x_ctrl = np.clip(drift + self.rng.normal(0, noise, 2), -5, 5)
        bounded = np.clip(self.x_chaos, 0.001, 0.999)
        self.x_chaos = np.where(bounded < 0.5, 2 * bounded, 2 * (1 - bounded))
        self._t += 1

    def observe(self):
        """Return a noisy observation mixing controllable state and chaos."""
        _, mix = self._get_phase()
        distractor = mix * (self.x_chaos - 0.5)
        return self.x_ctrl + distractor + self.rng.normal(0, 0.08, 2)

    def true_ctrl(self):
        """Return a copy of the hidden controllable state."""
        return self.x_ctrl.copy()

    def true_B(self):
        """Return the current phase's true control matrix (not a copy)."""
        return self._get_phase()[0]

# Agents' shared, fixed estimate of the drift matrix, used by do_null_diff,
# upd_B, KalmanLikeAgent and MetaAdaptiveAgent. It differs from
# MOATEnvV5c.A_true ([[0.90, 0.05], [-0.05, 0.90]]) — presumably an intentional
# model mismatch; confirm.
A_EST = np.array([[0.87, 0.03], [-0.03, 0.87]])

def do_null_diff(B_est, obs, action, obs_next):
    """Action-attributed G_hat (causal, do-null score).

    Compares two one-step predictions of ``obs_next`` built from A_EST:
      null: A_EST @ obs                   (as if no action were taken)
      do:   A_EST @ obs + B_est @ action  (including the estimated action effect)

    Returns:
        ||obs_next - null||^2 - ||obs_next - do||^2. It is positive when
        including the action term reduced the squared prediction error.
    """
    null_prediction = A_EST @ obs
    do_prediction = null_prediction + B_est @ action
    null_sq_err = np.linalg.norm(obs_next - null_prediction) ** 2
    do_sq_err = np.linalg.norm(obs_next - do_prediction) ** 2
    return null_sq_err - do_sq_err

def upd_B(B_est, obs, action, obs_next, lr):
    """Return B_est after one LMS-style step toward the observed transition.

    The residual is ``obs_next - (A_EST @ obs + B_est @ action)``. The update adds
    ``lr * outer(residual, action)`` and clips every entry to [-3, 3].
    The input array is not modified.
    """
    residual = obs_next - (A_EST @ obs + B_est @ action)
    correction = lr * np.outer(residual, action)
    return np.clip(B_est + correction, -3, 3)

# ─── Agents ───────────────────────────────────────────────

class BaseAgent:
    """State and behaviour shared by every agent.

    Attributes:
      name:   agent label; also determines the exploration-noise seed.
      B_est:  2x2 estimate of the control matrix, initialised to zeros.
      _rng:   numpy Generator used for exploration noise in act().
      _t:     number of update() calls so far.
    """
    base_lr = 0.015  # default learning rate; subclasses scale it

    def __init__(self, name):
        self.name = name
        self.B_est = np.zeros((2, 2))
        # Stable per-name seed. The built-in hash() of a str is salted per
        # process (PYTHONHASHSEED), which made agent noise — and therefore every
        # result — irreproducible across runs.
        self._rng = np.random.default_rng(zlib.crc32(name.encode("utf-8")))
        self._t = 0

    def act(self, obs):
        """Return (-0.30 * B_est) @ obs plus N(0, 0.08) noise, clipped to [-1.5, 1.5]."""
        return np.clip(-0.30 * self.B_est @ obs + self._rng.normal(0, 0.08, 2), -1.5, 1.5)

    def update(self, obs, action, obs_next):
        """Advance the step counter; subclasses add their B_est update on top."""
        self._t += 1

def make_agents():
    """Return one fresh instance of each agent, in a fixed order.

    The order is KalmanLike, MetaAdaptive, CFWorldModel, ChaosFake, Observer.
    The classes are looked up at call time, so defining them below this
    function is fine.
    """
    agent_classes = (KalmanLikeAgent, MetaAdaptiveAgent, CFWorldModelAgent,
                     ChaosFakeAgent, ObserverAgent)
    return [agent_cls() for agent_cls in agent_classes]

class KalmanLikeAgent(BaseAgent):
    """Negative control: pure observation-residual tracking with a constant lr.

    Every prediction error moves B_est toward the data, whether the action or
    the chaos term caused it. There is no do-null distinction.
    """

    def __init__(self):
        super().__init__("KalmanLike")

    def update(self, obs, action, obs_next):
        super().update(obs, action, obs_next)
        predicted = A_EST @ obs + self.B_est @ action
        residual = obs_next - predicted
        step = self.base_lr * np.outer(residual, action)
        self.B_est = np.clip(self.B_est + step, -3, 3)

class MetaAdaptiveAgent(BaseAgent):
    """Negative control: lr adapts to recent error magnitude, with no causal attribution.

    Keeps the last 15 residual norms. It sets
    lr = base_lr * (0.5 + m / (m + 0.3)), where m is their mean. So lr rises from
    0.5*base_lr (m -> 0) toward 1.5*base_lr (m -> inf) for ANY large error,
    chaos-driven ones included.
    """
    err_window = 15  # number of recent residual norms averaged into lr

    def __init__(self):
        super().__init__("MetaAdaptive")
        # Bounded window: appending evicts the oldest entry automatically.
        self._recent_err = deque(maxlen=self.err_window)

    def update(self, obs, action, obs_next):
        super().update(obs, action, obs_next)
        err = obs_next - (A_EST @ obs + self.B_est @ action)
        self._recent_err.append(float(np.linalg.norm(err)))
        # The window is never empty here: we just appended to it.
        mean_err = np.mean(self._recent_err)
        lr = self.base_lr * (0.5 + mean_err / (mean_err + 0.3))
        self.B_est = np.clip(self.B_est + lr * np.outer(err, action), -3, 3)

class CFWorldModelAgent(BaseAgent):
    """Ablation of ObserverAgent: the shared upd_B rule with a constant lr (no calibration).

    Unlike ObserverAgent, this agent never computes G_hat. With lr fixed at
    base_lr, its update is the same residual step that KalmanLikeAgent
    applies inline.
    """

    def __init__(self):
        super().__init__("CFWorldModel")

    def update(self, obs, action, obs_next):
        super().update(obs, action, obs_next)
        fixed_lr = self.base_lr
        self.B_est = upd_B(self.B_est, obs, action, obs_next, lr=fixed_lr)

class ChaosFakeAgent(BaseAgent):
    """Negative control: lr driven by observation magnitude, a signal unrelated to causality.

    With e = ||obs_next||^2, lr = base_lr * (0.3 + 0.7 * e / (e + 2.0)). It ranges
    from 0.3*base_lr up toward base_lr as the observation grows.
    """

    def __init__(self):
        super().__init__("ChaosFake")

    def update(self, obs, action, obs_next):
        super().update(obs, action, obs_next)
        energy = float(np.linalg.norm(obs_next) ** 2)
        scale = 0.3 + 0.7 * energy / (energy + 2.0)
        self.B_est = upd_B(self.B_est, obs, action, obs_next, lr=self.base_lr * scale)

class ObserverAgent(BaseAgent):
    """Causal-attribution agent: the learning rate is gated by the do-null score G_hat.

    Each step computes G_hat = do_null_diff(...). This is how much adding
    B_est @ action reduced the squared one-step prediction error. The agent then
    uses lr = base_lr * max(0.05, sigmoid(4 * G_hat)):
      G_hat = 0          -> 0.5  * base_lr
      G_hat >> 0         -> up to base_lr (update aggressively)
      G_hat << 0         -> down to the 0.05 * base_lr floor
    Intended causal filter: chaos-caused errors should not raise G_hat, so they
    should not drive large B_est updates.
    np.exp may overflow (RuntimeWarning) for very negative G_hat. The gate then
    evaluates to 0 and is clamped to the floor.
    """

    def __init__(self):
        super().__init__("Observer")

    def update(self, obs, action, obs_next):
        super().update(obs, action, obs_next)
        g_hat = do_null_diff(self.B_est, obs, action, obs_next)
        gate = max(0.05, 1.0 / (1 + np.exp(-4 * g_hat)))
        self.B_est = upd_B(self.B_est, obs, action, obs_next, lr=self.base_lr * gate)

# ─── Evaluation ───────────────────────────────────────────

def b_sim(B_est, B_true):
    """Cosine similarity between estimated and true B, per column, averaged.

    Corresponding columns of B_est and B_true are compared; a column is skipped
    when either vector has norm <= 1e-6 (cosine undefined). All paired columns
    are used; if the widths differ, the extra columns are ignored.

    Returns:
        Mean cosine over the non-skipped columns, or 0.0 if all were skipped.
    """
    sims = []
    for be, bt in zip(np.asarray(B_est).T, np.asarray(B_true).T):
        nb, nt = np.linalg.norm(be), np.linalg.norm(bt)
        if nb > 1e-6 and nt > 1e-6:
            sims.append(float(np.dot(be, bt) / (nb * nt)))
    return float(np.mean(sims)) if sims else 0.0

def run_episode_track(AgentCls, env_seed=0, n=400):
    """Track B_est alignment per timestep to see adaptation dynamics.

    Args:
        AgentCls: zero-argument agent class to instantiate.
        env_seed: seed for MOATEnvV5c.
        n: number of steps.

    Returns:
        Float array of length n. Element t is b_sim(agent.B_est, B_t) after the
        agent's update on transition t, where B_t is the control matrix that
        generated that transition.
    """
    env = MOATEnvV5c(seed=env_seed)
    agent = AgentCls()
    cos_track = []

    for _ in range(n):
        # Capture B before stepping. env.step() uses the current phase and then
        # advances the clock, so querying afterwards compared against the *next*
        # phase's B at t=99 and t=299.
        B_true = env.true_B()
        obs = env.observe()
        action = agent.act(obs)
        env.step(action)
        obs_nx = env.observe()
        agent.update(obs, action, obs_nx)
        cos_track.append(b_sim(agent.B_est, B_true))

    return np.array(cos_track)

def run_episode_metrics(AgentCls, env_seed=0, n=400):
    """Per-phase cos_sim and adaptation lag for one episode.

    Returns a dict with:
      'phase1_ctrl_change':      mean cos_sim over the last 30 steps of t=100-199
                                 (B rotated).
      'phase2_chaos_robustness': mean cos_sim over the last 30 steps of t=200-299
                                 (chaos increased).
      'adapt_lag_100' / 'adapt_lag_300': steps after the B_true change at t=100 /
                                 t=300 until cos_sim first exceeds 0.80, searched
                                 only within that 100-step phase. The value is 100
                                 if it never did.
    """
    env = MOATEnvV5c(seed=env_seed)
    agent = AgentCls()
    phase_ends = (100, 200, 300, 400)
    phase_sims = {end: [] for end in phase_ends}
    change_points = (100, 300)
    lag_window = 100  # phase length; also the "never adapted" value
    adapt_lag = {cp: None for cp in change_points}

    for t in range(n):
        # B that actually drives this transition: env.step() uses the current
        # phase and then advances the clock.
        B_true = env.true_B()
        obs = env.observe()
        action = agent.act(obs)
        env.step(action)
        obs_nx = env.observe()
        agent.update(obs, action, obs_nx)

        cs = b_sim(agent.B_est, B_true)
        for phase_end in phase_ends:
            if t < phase_end:
                phase_sims[phase_end].append(cs)
                break

        # Adaptation lag: restricted to the phase that follows each change, so a
        # late crossing cannot be scored against a different B_true.
        for cp in change_points:
            if adapt_lag[cp] is None and cp <= t < cp + lag_window and cs > 0.80:
                adapt_lag[cp] = t - cp

    return {
        'phase1_ctrl_change': np.mean(phase_sims[200][-30:]),
        'phase2_chaos_robustness': np.mean(phase_sims[300][-30:]),
        # `is None`, not truthiness: a lag of 0 (immediate adaptation) is valid.
        'adapt_lag_100': lag_window if adapt_lag[100] is None else adapt_lag[100],
        'adapt_lag_300': lag_window if adapt_lag[300] is None else adapt_lag[300],
    }

def summarize(n_seeds=6):
    """Print the MOAT v5c comparison table for all agents.

    Each agent class is run on env seeds 0..n_seeds-1 via run_episode_track(),
    and the per-step alignment traces are averaged over seeds. The table shows:
      - the mean of that averaged trace over the second half of each 100-step
        phase (t=50-99, 150-199, 250-299, 350-399);
      - adapt_lag: steps after the B_true rotation at t=100 until the averaged
        trace first exceeds 0.75 (100 if never within t=100-199).
    """
    AGENTS = [KalmanLikeAgent, MetaAdaptiveAgent, CFWorldModelAgent,
              ChaosFakeAgent, ObserverAgent]
    # Second half of each phase, presumably to skip the post-switch transient.
    phase_windows = ((50, 100), (150, 200), (250, 300), (350, 400))

    print("=" * 70)
    print("MOAT v5c  —  Causal Attribution Separation Test")
    print("=" * 70)

    print("""
Environment phases:
  t=  0-100  B_true = I*0.35        chaos_mix=0.3   (baseline)
  t=100-200  B_true = ROT*0.35      chaos_mix=0.3   (controllable change)
  t=200-300  B_true = ROT*0.35      chaos_mix=0.9   (UNcontrollable change)
  t=300-400  B_true = I*0.35        chaos_mix=0.9   (controllable change again)
""")

    print("Negative controls:")
    print("  KalmanLike:    obs residual, constant lr — no causal attribution")
    print("  MetaAdaptive:  error-magnitude adaptive lr — chases ALL errors")
    print()

    print("[ B_est alignment (cos_sim) per phase — averaged over seeds ]")

    tracks = {
        cls.__name__: np.mean(
            [run_episode_track(cls, env_seed=s) for s in range(n_seeds)], axis=0)
        for cls in AGENTS
    }

    print(f"  {'Agent':18s}  p1-base  p2-Brot  p3-chaos  p4-Bback  adapt_lag")
    print("  " + "-"*65)
    for cls in AGENTS:
        tr = tracks[cls.__name__]
        p1, p2, p3, p4 = (np.mean(tr[lo:hi]) for lo, hi in phase_windows)
        # NOTE(review): run_episode_metrics uses a 0.80 threshold for its lag.
        lag = next((i for i in range(100, 200) if tr[i] > 0.75), 200) - 100
        print(f"  {cls.__name__:18s}  {p1:.3f}    {p2:.3f}    {p3:.3f}     {p4:.3f}     {lag}")

    print()
    print("[ Separation prediction ]")
    print("  Observer should show:")
    print("    - Fast adaptation after B rotation (low adapt_lag)")
    print("    - HIGH p3-chaos (not fooled by chaos increase)")
    print("  KalmanLike/MetaAdaptive should show:")
    print("    - Slower or similar adaptation (no causal filter)")
    print("    - LOWER p3-chaos (chaos contaminates their B_est)")

# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    summarize()
