"""
4-qubit Observer-Agent Model  (v2 - clean)
===========================================
identity condition を直接測定できる形に再設計

Key insight:
  Controller:      gain = alpha (固定)
  Observer-Agent:  gain = f(Delta_cf) (Delta_cf 依存)

測定:  corr(Delta_cf[t],  gain[t])
  Controller   -> ~ 0
  ObserverAgent -> > 0  (設計通り)

また back-action の効果も明示的に記録する。
"""

import numpy as np
from scipy.linalg import expm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# ─── Pauli / qubit utilities ──────────────────────
I2 = np.identity(2, dtype=complex)                 # single-qubit identity
X = np.array([[0, 1], [1, 0]], dtype=complex)      # Pauli X (bit flip)
Z = np.diag(np.array([1, -1], dtype=complex))      # Pauli Z (phase flip)

def kron4(a, b, c, d):
    """
    Return the Kronecker product a (x) b (x) c (x) d.

    Factors are combined left to right, so `a` acts on qubit 0 (the most
    significant bit of the joint basis index) and `d` on qubit 3.
    """
    acc = np.kron(a, b)
    acc = np.kron(acc, c)
    return np.kron(acc, d)

def cz(n, q1, q2):
    """
    Return the 2**n x 2**n controlled-Z matrix acting on qubits q1 and q2.

    Qubit 0 is the most significant bit of the basis index: for basis index
    `idx`, the bit of qubit k is (idx >> (n-1-k)) & 1. The matrix is diagonal,
    with -1 wherever both q1 and q2 are 1 and +1 everywhere else.
    """
    size = 2**n
    signs = []
    for idx in range(size):
        bitstring = [(idx >> (n - 1 - k)) & 1 for k in range(n)]
        both_set = bitstring[q1] == 1 and bitstring[q2] == 1
        signs.append(-1 if both_set else 1)
    return np.diag(np.array(signs, dtype=complex))

# Intervention operators C_O, indexed by action id a = 0 .. N_OPS-1.
# Qubits 0-1 are the agent's "self", qubits 2-3 its environment.
OPS = [
    kron4(I2, I2, I2, I2),  # a=0: identity (no intervention)
    kron4(X, I2, I2, I2),   # a=1: bit flip on self qubit 0
    kron4(Z, I2, I2, I2),   # a=2: phase flip on self qubit 0
    cz(4, 0, 2),            # a=3: CZ between self qubit 0 and env qubit 2
]
N_OPS = len(OPS)  # number of available interventions

# ─── Quantum ops ──────────────────────────────────
def ptrace_env(rho, ds=4, de=4):
    """
    Trace out the environment from a joint (ds*de) x (ds*de) density matrix.

    The joint basis is self-major: joint index = self_index * de + env_index.
    With the defaults (ds = de = 4) this maps the 4-qubit (16x16) state to the
    2-qubit self state (4x4), where self = qubits 0-1 and env = qubits 2-3.

    Parameters
    ----------
    rho : array_like, shape (ds*de, ds*de)
        Joint density matrix.
    ds, de : int
        Dimensions of the self and environment subsystems.

    Returns
    -------
    numpy.ndarray, shape (ds, ds), complex
        r[j, k] = sum_i rho[j*de + i, k*de + i]

    Raises
    ------
    ValueError
        If rho is not a (ds*de) x (ds*de) matrix.
    """
    rho = np.asarray(rho, dtype=complex)
    if rho.shape != (ds * de, ds * de):
        raise ValueError(
            f"expected a {ds * de}x{ds * de} matrix, got shape {rho.shape}")
    # Axes after reshape: (self_row, env_row, self_col, env_col); trace the env pair.
    return np.trace(rho.reshape(ds, de, ds, de), axis1=1, axis2=3)

def U_rho(rho, U):
    """Return U rho U^dagger, i.e. rho conjugated by the (unitary) operator U."""
    U_dagger = U.conj().T
    return U @ rho @ U_dagger

def trace_dist(a, b):
    """
    Trace distance D(a, b) = (1/2) * ||a - b||_1.

    Computed as half the sum of the absolute eigenvalues of a - b, which
    assumes a - b is Hermitian (eigvalsh is the Hermitian eigen-solver).
    """
    spectrum = np.linalg.eigvalsh(a - b)
    return 0.5 * np.abs(spectrum).sum()

def normalize_dm(rho):
    """
    Project rho onto its Hermitian part and rescale it to unit trace.

    The Hermitian part (rho + rho^dagger) / 2 is divided by the real part of
    its trace when that trace exceeds 1e-12. Otherwise the Hermitian part is
    returned unscaled, to avoid dividing by (near) zero.
    """
    herm = 0.5 * (rho + rho.conj().T)
    tr = np.trace(herm).real
    if tr > 1e-12:
        return herm / tr
    return herm

def back_action(rho, ds=4, de=4):
    """
    Self-measurement back-action channel S(rho) = sum_i P_i rho P_i.

    P_i = |i><i|_self (x) I_env, with the self-major joint index i*de + j.
    The net effect is to keep the blocks of rho that are diagonal in the
    self index and zero every coherence between different self basis
    states. This channel is irreversible (the discarded coherences cannot be
    recovered), which in this model is the source of the perspective's
    incompleteness.

    Implemented as an element-wise block-diagonal mask rather than by
    building ds dense projectors and multiplying them in.

    Parameters
    ----------
    rho : array_like, shape (ds*de, ds*de)
        Joint density matrix.
    ds, de : int
        Dimensions of the self and environment subsystems (default: 2+2 qubits).

    Returns
    -------
    numpy.ndarray
        Same shape and dtype as rho.

    Raises
    ------
    ValueError
        If rho is not a (ds*de) x (ds*de) matrix.
    """
    rho = np.asarray(rho)
    dim = ds * de
    if rho.shape != (dim, dim):
        raise ValueError(f"expected a {dim}x{dim} matrix, got shape {rho.shape}")
    self_index = np.arange(dim) // de
    # True where row and column belong to the same self basis state.
    same_self = self_index[:, None] == self_index[None, :]
    return np.where(same_self, rho, 0).astype(rho.dtype, copy=False)

def delta_cf(rho, pi):
    """
    Weighted pairwise distinguishability of intervention-conditioned futures.

        Delta_cf = sum_{a < a'} pi[a] * pi[a'] * D(U_a rho U_a^dag, U_a' rho U_a'^dag)

    U_a ranges over OPS, pi is a weight per intervention and D is the trace
    distance. This is the core quantity of the identity condition: how
    distinguishable, on average, the futures produced by different
    interventions are.
    """
    outcomes = [U_rho(rho, op) for op in OPS]
    total = 0.0
    for a, fut_a in enumerate(outcomes):
        for b in range(a + 1, N_OPS):
            total += pi[a] * pi[b] * trace_dist(fut_a, outcomes[b])
    return total

# ─── Agents ───────────────────────────────────────
class Controller:
    """
    Layer 1: baseline that does NOT use Delta_cf in its self-model update.

    Delta_cf is still computed and logged every step so it can be compared
    with the other layers, but the update gain is the constant ALPHA. The
    identity condition therefore fails by construction: gain never varies,
    so it cannot correlate with Delta_cf.
    """
    ALPHA = 0.25

    def __init__(self):
        # Self-model of the 2-qubit self subsystem; starts maximally mixed (I/4).
        self.M = normalize_dm(np.eye(4, dtype=complex))
        self.gains = []       # gain actually used for each M update
        self.delta_cf_t = []  # Delta_cf of the state each step started from

    def step(self, rho, rng):
        """
        Advance one step from the joint state rho.

        Logs Delta_cf(rho) under a uniform policy, applies one uniformly drawn
        intervention from OPS, then mixes the post-intervention self state
        into M with the fixed gain ALPHA. Returns the normalized
        post-intervention joint state.
        """
        weights = np.full(N_OPS, 1.0 / N_OPS)
        dc_now = delta_cf(rho, weights)

        choice = rng.choice(N_OPS, p=weights)
        after = normalize_dm(U_rho(rho, OPS[choice]))

        g = self.ALPHA  # deliberately independent of dc_now
        observed_self = normalize_dm(ptrace_env(after))
        self.M = normalize_dm((1 - g) * self.M + g * observed_self)

        self.gains.append(g)
        self.delta_cf_t.append(dc_now)
        return after


class ObserverAgent:
    """
    Layer 2: agent that satisfies the identity condition.

    Each step the self-model M is updated with a gain that depends on the
    Delta_cf of the state the step starts from:

        gain = ALPHA * (1 - exp(-SENSITIVITY * Delta_cf))

    The gain is 0 at Delta_cf = 0, increases monotonically with Delta_cf and
    saturates at ALPHA. It is approximately proportional to Delta_cf
    (~ ALPHA * SENSITIVITY * Delta_cf) only for small Delta_cf. Because the
    gain is an increasing function of Delta_cf, corr(Delta_cf, gain) > 0
    whenever Delta_cf varies over a run.
    """
    ALPHA = 0.5          # maximum (saturation) gain
    SENSITIVITY = 12.0   # how fast the gain saturates as Delta_cf grows

    def __init__(self):
        # Self-model of the 2-qubit self subsystem; starts maximally mixed (I/4).
        self.M = normalize_dm(np.eye(4, dtype=complex))
        self.gains      = []   # gain actually used for each M update
        self.delta_cf_t = []   # Delta_cf of the state each step started from

    def _gain(self, dc):
        """Map Delta_cf to the M-update gain (the identity-condition rule)."""
        return self.ALPHA * (1 - np.exp(-self.SENSITIVITY * dc))

    def step(self, rho, rng):
        """
        Advance one step from the joint state rho.

        Computes Delta_cf(rho) under a uniform policy and applies one
        uniformly drawn intervention from OPS. It then moves M toward the
        post-intervention self state by the Delta_cf-dependent gain.
        Returns the normalized post-intervention joint state.
        """
        pi = np.ones(N_OPS) / N_OPS
        dc = delta_cf(rho, pi)

        a = rng.choice(N_OPS, p=pi)
        rho2 = normalize_dm(U_rho(rho, OPS[a]))

        # Core of the identity condition: the gain is a function of Delta_cf.
        gain = self._gain(dc)

        rho_s = normalize_dm(ptrace_env(rho2))
        self.M = normalize_dm((1-gain)*self.M + gain*rho_s)

        self.gains.append(gain)
        self.delta_cf_t.append(dc)
        return rho2


class PerspectivalObserver(ObserverAgent):
    """
    Layer 3: Observer-Agent whose act of self-reading disturbs the state.

    After the intervention, the joint state is passed through the
    self-measurement channel back_action. The self-model M is updated from
    the disturbed state, and the disturbed state is what gets returned, so
    the disturbance propagates to later steps. The agent cannot read its
    future without changing itself, so its self-model is built from a
    distorted view (perspective incompleteness: complete externalization is
    impossible). The gain rule is the same Delta_cf-dependent one as in
    ObserverAgent.

    NOTE(review): back_action leaves the state block-diagonal in the self
    basis. Every element of OPS, as well as the env-only noise applied in
    simulate(), appears to preserve that structure, so backaction_magnitudes
    may be ~0 after the first step. Confirm this is the intended dynamics.
    """

    def __init__(self):
        super().__init__()
        # Trace distance D(rho, S(rho)) caused by back-action, one entry per step.
        self.backaction_magnitudes = []

    def step(self, rho, rng):
        """
        Advance one step from the joint state rho, including back-action.

        Logs Delta_cf(rho) under a uniform policy and applies one uniformly
        drawn intervention from OPS. It then applies back_action, logs how
        far that moved the state, and updates M from the disturbed self
        state. Returns the disturbed joint state.
        """
        policy = np.full(N_OPS, 1.0 / N_OPS)
        dc = delta_cf(rho, policy)

        action = rng.choice(N_OPS, p=policy)
        after_op = normalize_dm(U_rho(rho, OPS[action]))

        # Self-measurement back-action, and the size of the disturbance.
        disturbed = normalize_dm(back_action(after_op))
        self.backaction_magnitudes.append(trace_dist(after_op, disturbed))

        # Identity condition is kept: the gain still depends on Delta_cf.
        g = self.ALPHA * (1 - np.exp(-12 * dc))
        observed_self = normalize_dm(ptrace_env(disturbed))
        self.M = normalize_dm((1 - g) * self.M + g * observed_self)

        self.gains.append(g)
        self.delta_cf_t.append(dc)
        # The trajectory continues from the post-back-action state.
        return disturbed


# ─── Simulation ───────────────────────────────────
def simulate(n=100, seed=0, init_mix=0.008, env_sigma=0.12):
    """
    Run Controller, ObserverAgent and PerspectivalObserver side by side.

    All three agents start from the same 4-qubit state rho0. At every step
    they receive the same environment-noise unitary

        Un = expm(i*th*X_2 + i*ph*Z_3),   th, ph ~ Normal(0, env_sigma),

    which acts only on environment qubits 2 and 3. Each agent draws its own
    intervention from the shared generator, so their action sequences
    generally differ.

    Parameters
    ----------
    n : int
        Number of steps.
    seed : int
        Seed for numpy.random.default_rng.
    init_mix : float in [0, 1]
        Weight of a random full-rank density matrix mixed into |0000><0000|
        to form rho0.
    env_sigma : float
        Standard deviation of the noise angles th and ph.

    Returns
    -------
    tuple
        (Controller, ObserverAgent, PerspectivalObserver) with their
        per-step logs filled in.

    Raises
    ------
    ValueError
        If init_mix is outside [0, 1].
    """
    if not 0.0 <= init_mix <= 1.0:
        raise ValueError(f"init_mix must be in [0, 1], got {init_mix}")
    rng = np.random.default_rng(seed)

    psi = np.zeros(16, dtype=complex)
    psi[0] = 1.0
    pure = np.outer(psi, psi.conj())
    # Small initial noise. Adding an arbitrary Hermitian perturbation to a pure
    # state generally yields negative eigenvalues (not a density matrix), so
    # mix with the PSD matrix G G^dagger / tr(G G^dagger) instead. Two 16x16
    # standard-normal draws are consumed here, exactly as an additive
    # perturbation would, so the per-step random stream below is unchanged.
    G = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
    W = G @ G.conj().T
    W /= np.real(np.trace(W))
    rho0 = normalize_dm((1 - init_mix) * pure + init_mix * W)

    ctrl  = Controller()
    agent = ObserverAgent()
    persp = PerspectivalObserver()

    rc, ra, rp = rho0.copy(), rho0.copy(), rho0.copy()

    # Noise generators are loop-invariant; only the angles are redrawn.
    noise_x = kron4(I2, I2, X, I2)
    noise_z = kron4(I2, I2, I2, Z)

    for _ in range(n):
        th = rng.normal(0, env_sigma)
        ph = rng.normal(0, env_sigma)
        Un = expm(1j*th*noise_x + 1j*ph*noise_z)

        rc = normalize_dm(U_rho(ctrl.step(rc, rng),  Un))
        ra = normalize_dm(U_rho(agent.step(ra, rng), Un))
        rp = normalize_dm(U_rho(persp.step(rp, rng), Un))

    return ctrl, agent, persp


# ─── Metrics ──────────────────────────────────────
def ic_score(agent):
    """
    Identity-condition score: Pearson corr(Delta_cf[t], gain[t]).

    Reads the agent's `delta_cf_t` and `gains` logs.
      Controller    -> 0.0  (gain is constant, so r is undefined; reported as 0)
      ObserverAgent -> > 0  (gain is an increasing function of Delta_cf)

    Returns 0.0 whenever Pearson's r is undefined. That is the case with
    fewer than two samples, or when either series has (numerically) zero
    variance.
    """
    d = np.asarray(agent.delta_cf_t, dtype=float)
    g = np.asarray(agent.gains, dtype=float)
    # Too few samples: std/corrcoef would produce NaN (with warnings).
    if min(d.size, g.size) < 2:
        return 0.0
    if np.std(d) < 1e-12 or np.std(g) < 1e-12:
        return 0.0
    return float(np.corrcoef(d, g)[0, 1])


# ─── Plot ─────────────────────────────────────────
def _draw_panels(fig, ctrl, agent, persp, n):
    """Draw the five summary panels onto `fig` (does not save or close it)."""
    steps = np.arange(n)
    fig.patch.set_facecolor('#0a0a1a')
    gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.5, wspace=0.35)

    # Colour and legend label per agent, shared by every panel.
    C = {'ctrl': '#5599ff', 'agent': '#ff7744', 'persp': '#44ffaa'}
    L = {'ctrl': 'Controller (Layer 1)',
         'agent': 'Observer-Agent (Layer 2)',
         'persp': 'Perspectival Observer (Layer 3)'}
    series = [('ctrl', ctrl), ('agent', agent), ('persp', persp)]

    def style(ax, title):
        """Apply the shared dark theme and a title to one axes."""
        ax.set_facecolor('#0f0f2a')
        ax.set_title(title, color='white', pad=7, fontsize=9)
        ax.tick_params(colors='#888888', labelsize=8)
        for sp in ax.spines.values():
            sp.set_color('#333355')

    # Panel 1: Delta_cf over time
    ax = fig.add_subplot(gs[0, :])
    for k, obj in series:
        ax.plot(steps, obj.delta_cf_t, color=C[k], lw=1.4, alpha=0.8, label=L[k])
    style(ax, 'Delta_cf  (intervention-conditioned future distinguishability)')
    ax.set_xlabel('step', color='#aaa', fontsize=8)
    ax.set_ylabel('Delta_cf', color='#aaa', fontsize=8)
    ax.legend(fontsize=7.5, facecolor='#111133', labelcolor='white')

    # Panel 2: gain over time
    ax = fig.add_subplot(gs[1, 0])
    for k, obj in series:
        ax.plot(steps, obj.gains, color=C[k], lw=1.4, alpha=0.8, label=L[k])
    style(ax, 'Gain applied to M update\n(Agent: gain = f(Delta_cf);  Ctrl: constant)')
    ax.set_xlabel('step', color='#aaa', fontsize=8)
    ax.set_ylabel('gain', color='#aaa', fontsize=8)
    ax.legend(fontsize=7, facecolor='#111133', labelcolor='white')

    # Panel 3: Delta_cf vs gain scatter (a non-flat trend = identity condition)
    ax = fig.add_subplot(gs[1, 1])
    for k, obj in series:
        ax.scatter(np.array(obj.delta_cf_t), np.array(obj.gains),
                   color=C[k], s=8, alpha=0.45, label=L[k])
    style(ax, 'Delta_cf  vs  gain\n[identity condition: slope != 0 for Agent]')
    ax.set_xlabel('Delta_cf', color='#aaa', fontsize=8)
    ax.set_ylabel('gain', color='#aaa', fontsize=8)
    ax.legend(fontsize=7, facecolor='#111133', labelcolor='white')

    # Panel 4: back-action magnitude (Perspectival Observer only)
    ax = fig.add_subplot(gs[2, 0])
    ax.plot(steps, persp.backaction_magnitudes,
            color=C['persp'], lw=1.4, alpha=0.85,
            label='back-action magnitude D(rho, S(rho))')
    ax.axhline(np.mean(persp.backaction_magnitudes),
               color='white', lw=0.8, ls='--', alpha=0.5, label='mean')
    style(ax, 'Self-measurement back-action (Layer 3)\nD(rho, S(rho)) > 0  =>  perspective incomplete')
    ax.set_xlabel('step', color='#aaa', fontsize=8)
    ax.set_ylabel('trace distance', color='#aaa', fontsize=8)
    ax.legend(fontsize=7.5, facecolor='#111133', labelcolor='white')

    # Panel 5: identity-condition score per agent
    ax = fig.add_subplot(gs[2, 1])
    names  = ['Controller\n(Layer 1)', 'Observer-\nAgent (L2)', 'Perspectival\n(L3)']
    scores = [ic_score(ctrl), ic_score(agent), ic_score(persp)]
    bars = ax.bar(names, scores,
                  color=[C['ctrl'], C['agent'], C['persp']],
                  width=0.5, alpha=0.85)
    ax.axhline(0, color='white', lw=0.8)
    for bar, s in zip(bars, scores):
        # Value label just above positive bars / just below negative ones.
        ax.text(bar.get_x() + bar.get_width() / 2,
                s + 0.01 * (1 if s >= 0 else -1),
                f'{s:+.3f}', ha='center', va='bottom' if s >= 0 else 'top',
                color='white', fontsize=8)
    style(ax, 'Identity condition score\ncorr(Delta_cf,  gain)\nAgent >> Controller')
    ax.set_ylabel('Pearson r', color='#aaa', fontsize=8)
    ax.set_ylim(-0.2, 1.05)

    fig.suptitle(
        '4-qubit Observer-Agent Model\n'
        'Observer as Operational Reconstruction Structure',
        color='white', fontsize=12, y=0.99)


def plot(ctrl, agent, persp, n, path='/mnt/user-data/outputs/observer_agent.png'):
    """
    Render the 5-panel summary figure and save it as a PNG.

    Panels: (1) Delta_cf over time, (2) gain over time, (3) Delta_cf vs gain
    scatter, (4) back-action magnitude of the Perspectival Observer, and
    (5) ic_score per agent.

    Parameters
    ----------
    ctrl, agent, persp : Controller, ObserverAgent, PerspectivalObserver
        Agents after simulate(). Their per-step logs (delta_cf_t, gains, and
        persp.backaction_magnitudes) are plotted.
    n : int
        Number of steps; used as the x axis, so each log should have length n.
    path : str, optional
        Output file (defaults to the original hard-coded location). Its
        parent directory is created if missing.

    Returns
    -------
    str
        The path the figure was written to.
    """
    import os

    fig = plt.figure(figsize=(14, 10))
    try:
        _draw_panels(fig, ctrl, agent, persp, n)
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(path, dpi=150, bbox_inches='tight',
                    facecolor=fig.get_facecolor())
    finally:
        # Release this specific figure even if drawing or saving fails.
        plt.close(fig)
    return path

# ─── Main ─────────────────────────────────────────
if __name__ == '__main__':
    N = 100
    print("Running simulation ...")
    ctrl, agent, persp = simulate(n=N, seed=7)
    # Report order is fixed: Layer 1, Layer 2, Layer 3.
    trio = (ctrl, agent, persp)

    print("\n=== Identity Condition Score  corr(Delta_cf, gain) ===")
    for label, obj in zip(('Controller', 'Observer-Agent', 'Perspectival Obs.'), trio):
        score = ic_score(obj)
        meter = '#' * max(0, int(score * 40))   # crude text bar, 40 chars = r of 1
        print(f"  {label:<20}: {score:+.4f}  |{meter}")

    print("\n=== Delta_cf statistics ===")
    for label, obj in zip(('Controller', 'ObserverAgent', 'Perspectival'), trio):
        dcs = np.asarray(obj.delta_cf_t)
        print(f"  {label:<16}: mean={dcs.mean():.4f}  std={dcs.std():.4f}")

    print("\n=== Back-action (Layer 3) ===")
    magnitudes = np.asarray(persp.backaction_magnitudes)
    print(f"  mean D(rho, S(rho)) = {magnitudes.mean():.4f}  (> 0 => perspective incomplete)")

    plot(ctrl, agent, persp, N)
    print("\nDone. -> observer_agent.png")
