"""
Emergence Simulation (simplified & honest)
==========================================

問題: 前の量子版では Delta_cf がほぼ定数になった
     → その環境では g1 の値が適応度に影響しない

解決: Delta_cf が真に変動する環境を設計する
     - 信号/ノイズ比が時系列で変動
     - Delta_cf_proxy = その変動の観測可能な指標
     - gain = g0 + g1 * delta_cf_proxy で進化させる

示すこと:
     Delta_cf が変動する環境では、g1 > 0 が自然選択される
     (= observer-agent が evolutionary attractor)

正直な制限:
     これは量子模型の近似。
     完全な量子版には richer operator set が必要。
"""
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def make_env(n_steps=100, seed=0, *, period=30, signal_len=18,
             jump_prob=0.08, drift_std=0.05):
    """
    Generate one environment: a hidden-state trajectory and a Delta_cf proxy.

    Parameters
    ----------
    n_steps : int
        Number of time steps (must be >= 1).
    seed : int
        Seed for the hidden-state trajectory only. The Delta_cf proxy does
        not use it: it is a fixed, deterministic schedule.
    period, signal_len : int
        The Delta_cf schedule repeats every ``period`` steps. The first
        ``signal_len`` steps of each period are the "signal" phase (high
        Delta_cf); the remaining steps are the "noise" phase (low Delta_cf).
    jump_prob : float
        Per-step probability that the hidden state jumps to a fresh
        Dirichlet(1, 1, 1, 1) sample.
    drift_std : float
        Std. dev. of the Gaussian drift applied when no jump occurs. The
        drifted vector is made non-negative with abs() and renormalised.

    Returns
    -------
    s_true : ndarray, shape (n_steps, 4)
        Hidden state. Each row is a probability vector; row 0 is [1, 0, 0, 0].
    dcf : ndarray, shape (n_steps,)
        Delta_cf proxy (how much an intervention would make the future
        diverge). It lies in [0.20, 0.50] during signal phases and in
        [0.02, 0.07] during noise phases.

    Raises
    ------
    ValueError
        If ``n_steps < 1`` or ``period < 1``.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")
    if period < 1:
        raise ValueError("period must be >= 1")

    rng = np.random.default_rng(seed)
    s_true = np.zeros((n_steps, 4))
    s_true[0] = np.array([1, 0, 0, 0], dtype=float)

    for t in range(1, n_steps):
        # Mostly slow drift, with occasional abrupt regime changes.
        if rng.random() < jump_prob:
            new_s = rng.dirichlet(np.ones(4))
        else:
            new_s = s_true[t - 1] + rng.normal(0, drift_std, 4)
            new_s = np.abs(new_s)
            new_s /= new_s.sum()
        s_true[t] = new_s

    # Delta_cf proxy: how much an intervention branches the future.
    #   signal phase (high Delta_cf): actions matter
    #   noise phase  (low Delta_cf):  noise dominates
    dcf = np.zeros(n_steps)
    for t in range(n_steps):
        wave = np.abs(np.sin(t * np.pi / 15))
        if (t % period) < signal_len:
            dcf[t] = 0.20 + 0.30 * wave
        else:
            dcf[t] = 0.02 + 0.05 * wave

    return s_true, dcf

def evaluate_individual(g0, g1, s_true, dcf, noise_amp=0.40, noise_seed=None):
    """
    Score genotype (g0, g1) by how well its gain-modulated filter tracks s_true.

    At each step the agent sees a noisy, renormalised observation of
    ``s_true[t]``. The noise std is ``noise_amp * (1 - dcf[t])``, so
    low-Delta_cf steps are noisier. The agent updates its estimate ``M`` by
    exponential smoothing with gain ``clip(g0 + g1 * dcf[t], 0.01, 0.99)``.

    Parameters
    ----------
    g0, g1 : float
        Base gain and Delta_cf sensitivity.
    s_true : ndarray, shape (n_steps, 4)
        Hidden-state trajectory (rows are probability vectors).
    dcf : ndarray, shape (>= n_steps,)
        Delta_cf proxy per step.
    noise_amp : float
        Observation-noise amplitude.
    noise_seed : int or None
        Seed for the observation noise. ``None`` (the default) derives it
        from ``hash((g0, g1))``, so every genotype gets its own fixed noise
        realisation. Pass an explicit seed to give different genotypes
        identical noise (common random numbers).

    Returns
    -------
    float
        Negative mean L1 tracking error. Higher is better; 0 is perfect.

    Raises
    ------
    ValueError
        If ``s_true`` is empty or ``dcf`` is shorter than ``s_true``.
    """
    n = len(s_true)
    if n == 0:
        raise ValueError("s_true must contain at least one step")
    if len(dcf) < n:
        raise ValueError("dcf must have at least len(s_true) entries")

    if noise_seed is None:
        # Float/tuple hashing is not randomised by PYTHONHASHSEED, so this
        # seed is reproducible across runs of the same Python version.
        noise_seed = hash((g0, g1)) % 2**31
    rng = np.random.default_rng(noise_seed)

    M = np.array([0.25, 0.25, 0.25, 0.25])  # initial latent estimate (uniform)
    total_err = 0.0

    for t in range(n):
        dc = dcf[t]
        # Observation noise: observations get rougher when Delta_cf is low.
        noise_scale = noise_amp * (1 - dc)
        obs = s_true[t] + rng.normal(0, noise_scale, 4)
        obs = np.abs(obs)
        obs /= obs.sum()

        # gain = g0 + g1 * Delta_cf, kept strictly inside (0, 1)
        g = np.clip(g0 + g1 * dc, 0.01, 0.99)

        # Exponential-smoothing update of the estimate
        M = (1 - g) * M + g * obs
        M = np.abs(M)
        M /= M.sum()

        total_err += np.sum(np.abs(M - s_true[t]))

    return -total_err / n  # higher is better

def evolve_population(n_steps=120, pop_size=50, n_gen=80,
                      mutation_std=0.06, seed=42, *,
                      n_eval_envs=4, tournament_size=4):
    """
    Evolve a population of (g0, g1) gain genotypes by tournament selection.

    Every generation, each individual's fitness is the mean of
    ``evaluate_individual`` over the same ``n_eval_envs`` environments
    (seeds ``seed + 7*k``). Sharing the environments keeps tournament
    comparisons fair. Each slot of the next generation is filled by the
    winner of a random tournament of ``tournament_size`` individuals, plus
    Gaussian mutation: std ``mutation_std`` for g0 and ``1.8 * mutation_std``
    for g1.

    Returns
    -------
    hist_g0, hist_g1, hist_fit : list of float
        Per-generation population means of g0, g1 and fitness (length n_gen).
    pop_g0, pop_g1 : ndarray, shape (pop_size,)
        The last evaluated population.
    fits : ndarray, shape (pop_size,)
        Fitness of exactly that population, aligned index-by-index.
    s_true, dcf : ndarray
        A representative environment (seed ``seed + 1``), used for plotting.

    Raises
    ------
    ValueError
        If ``n_gen < 1``, ``n_eval_envs < 1``, or
        ``tournament_size`` is not in ``[1, pop_size]``.
    """
    if n_gen < 1:
        raise ValueError("n_gen must be >= 1")
    if n_eval_envs < 1:
        raise ValueError("n_eval_envs must be >= 1")
    if not 1 <= tournament_size <= pop_size:
        raise ValueError("tournament_size must be in [1, pop_size]")

    rng = np.random.default_rng(seed)
    s_true, dcf = make_env(n_steps=n_steps, seed=seed + 1)

    # Build the evaluation environments once and share them across all
    # individuals and generations. A per-individual seed would score
    # individuals on different environments, biasing tournament selection.
    eval_envs = [make_env(n_steps, seed=seed + k * 7) for k in range(n_eval_envs)]

    # Initial population: g0 in [0.1, 0.6], g1 in [-0.5, 1.5]
    pop_g0 = rng.uniform(0.1, 0.6, pop_size)
    pop_g1 = rng.uniform(-0.5, 1.5, pop_size)

    hist_g0, hist_g1, hist_fit = [], [], []

    for gen in range(n_gen):
        fits = np.array([
            np.mean([evaluate_individual(g0, g1, env_s, env_d)
                     for env_s, env_d in eval_envs])
            for g0, g1 in zip(pop_g0, pop_g1)
        ])

        hist_g0.append(pop_g0.mean())
        hist_g1.append(pop_g1.mean())
        hist_fit.append(fits.mean())

        # Do not breed after the final evaluation, so that the returned
        # population and `fits` describe the same individuals.
        if gen == n_gen - 1:
            break

        # Tournament selection + mutation
        new_g0 = np.empty(pop_size)
        new_g1 = np.empty(pop_size)
        for i in range(pop_size):
            contenders = rng.choice(pop_size, size=tournament_size, replace=False)
            winner = contenders[np.argmax(fits[contenders])]
            new_g0[i] = pop_g0[winner] + rng.normal(0, mutation_std)
            new_g1[i] = pop_g1[winner] + rng.normal(0, mutation_std * 1.8)

        pop_g0, pop_g1 = new_g0, new_g1

    return hist_g0, hist_g1, hist_fit, pop_g0, pop_g1, fits, s_true, dcf

def plot(hist_g0, hist_g1, hist_fit, pop_g0, pop_g1, fits, s_true, dcf,
         out_path='/mnt/user-data/outputs/emergence_simulation.png'):
    """
    Render the three-panel summary figure and save it as a PNG.

    Panels: (1) mean g0/g1 per generation, (2) final population in (g0, g1)
    space coloured by fitness, with the top quartile starred, (3) the
    Delta_cf proxy time series.

    Parameters
    ----------
    hist_g0, hist_g1 : sequence of float
        Per-generation population means of g0 and g1 (must be non-empty).
    hist_fit : sequence of float
        Per-generation mean fitness. Accepted for interface compatibility;
        not drawn.
    pop_g0, pop_g1, fits : array-like, shape (pop_size,)
        Final population genotypes and the fitness values used to colour them.
    s_true : ndarray
        Hidden-state trajectory. Accepted for interface compatibility; not drawn.
    dcf : array-like, shape (n_steps,)
        Delta_cf proxy time series shown in panel 3.
    out_path : str or path-like
        Output PNG path. Missing parent directories are created.
    """
    from pathlib import Path

    pop_g0 = np.asarray(pop_g0)
    pop_g1 = np.asarray(pop_g1)
    fits = np.asarray(fits)

    gens = np.arange(len(hist_g1))
    fig = plt.figure(figsize=(14, 9))
    fig.patch.set_facecolor('#0a0a1a')
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.38)

    def sty(ax, t):
        # Shared dark-theme styling for every panel.
        ax.set_facecolor('#0f0f2a')
        ax.set_title(t, color='white', pad=7, fontsize=9)
        ax.tick_params(colors='#888', labelsize=8)
        for sp in ax.spines.values():
            sp.set_color('#333355')

    # Panel 1: evolution of g1 (and g0)
    ax = fig.add_subplot(gs[0, :])
    ax.plot(gens, hist_g1, color='#44ffaa', lw=2.2,
            label='mean g1 (Delta_cf sensitivity)')
    ax.plot(gens, hist_g0, color='#5599ff', lw=1.5, alpha=0.7,
            label='mean g0 (base gain)')
    ax.axhline(0, color='white', lw=1.0, ls='--', alpha=0.6,
               label='g1=0 (no Delta_cf sensitivity)')
    g1_final = hist_g1[-1]
    ax.fill_between(gens, 0, hist_g1,
                    where=np.array(hist_g1) > 0,
                    color='#44ffaa', alpha=0.12)
    # Only claim selection of the observer-agent when the data support it.
    if g1_final > 0:
        note = (f'converged g1={g1_final:.3f} > 0\n'
                '=> observer-agent naturally selected')
    else:
        note = (f'converged g1={g1_final:.3f} <= 0\n'
                '=> no selection for Delta_cf sensitivity')
    ax.annotate(note,
                xy=(gens[-1], g1_final), xytext=(gens[-1] * 0.6, g1_final + 0.2),
                color='#44ffaa', fontsize=8.5,
                arrowprops=dict(arrowstyle='->', color='#44ffaa', lw=1.2))
    sty(ax, 'Evolution of gain sensitivity (g1)\n'
            'g1 > 0 means: larger update when Delta_cf is high (observer-agent regime)')
    ax.set_xlabel('generation', color='#aaa', fontsize=8)
    ax.set_ylabel('g1 / g0', color='#aaa', fontsize=8)
    ax.legend(fontsize=8, facecolor='#111133', labelcolor='white', loc='lower right')

    # Panel 2: final-generation scatter plot
    ax = fig.add_subplot(gs[1, 0])
    sc = ax.scatter(pop_g0, pop_g1, c=fits, cmap='plasma', s=35, alpha=0.8)
    ax.axhline(0, color='white', lw=0.8, ls='--', alpha=0.5)
    plt.colorbar(sc, ax=ax, label='fitness')
    sty(ax, 'Final population distribution\n(g1 > 0 zone = observer-agent region)')
    ax.set_xlabel('g0 (base gain)', color='#aaa', fontsize=8)
    ax.set_ylabel('g1 (Delta_cf sensitivity)', color='#aaa', fontsize=8)
    top_fits = np.percentile(fits, 75)
    mask = fits > top_fits
    ax.scatter(pop_g0[mask], pop_g1[mask], c='white', s=80, alpha=0.5,
               marker='*', label='top 25%')
    ax.legend(fontsize=7.5, facecolor='#111133', labelcolor='white')

    # Panel 3: Delta_cf time series
    ax = fig.add_subplot(gs[1, 1])
    steps = np.arange(len(dcf))
    ax.plot(steps, dcf, color='white', lw=1.4, alpha=0.9)
    ax.fill_between(steps, 0, dcf, color='#44ffaa', alpha=0.15)
    ax.axhline(0.1, color='#ffaa33', lw=0.8, ls='--', alpha=0.6,
               label='low Delta_cf region')
    sty(ax, 'Delta_cf (env. informativeness)\n'
            'alternates between signal/noise phases')
    ax.set_xlabel('step', color='#aaa', fontsize=8)
    ax.set_ylabel('Delta_cf proxy', color='#aaa', fontsize=8)
    ax.legend(fontsize=7.5, facecolor='#111133', labelcolor='white')

    fig.suptitle(
        'Selection Principle: Evolutionary emergence of Delta_cf-sensitive updating\n'
        'Environment with variable informativeness selects g1 > 0 (observer-agent)',
        color='white', fontsize=10, y=0.99)

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight',
                facecolor=fig.get_facecolor())
    plt.close(fig)  # close this figure explicitly, not whatever is "current"

if __name__ == '__main__':
    # Script entry point: run the evolution, report the verdict on g1, and
    # write the summary figure.
    print("Running emergence simulation (simplified env)...")
    (hist_g0, hist_g1, hist_fit,
     pop_g0, pop_g1, fits, s_true, dcf) = evolve_population(
        n_steps=120, pop_size=50, n_gen=80, seed=42)

    # |final mean g1| <= threshold is reported as neutral.
    threshold = 0.05
    g1f = hist_g1[-1]
    g0f = hist_g0[-1]
    print(f"\nFinal: g0={g0f:.4f}, g1={g1f:.4f}")
    print()
    if g1f > threshold:
        print(">>> g1 > 0: Delta_cf-sensitive update naturally selected")
        for d in (0.0, 0.3, 0.5):
            print(f"    gain at dcf={d:.1f}: {np.clip(g0f + g1f * d, 0.01, 0.99):.3f}")
        print("\n    -> Selection principle demonstrated in simplified environment.")
        print("    -> Full quantum version needs richer operator set (noted limitation).")
    elif g1f < -threshold:
        # A clearly negative g1 is not "neutral": sensitivity was selected against.
        print(">>> g1 < 0: Delta_cf sensitivity selected against in this environment")
    else:
        print(">>> g1 ~ 0: neutral in this environment")

    # Top individuals by fitness (fewer than 10 if the population is smaller).
    top = np.argsort(fits)[-10:]
    print(f"\nTop {len(top)} individuals: mean g1={pop_g1[top].mean():.4f}")

    plot(hist_g0, hist_g1, hist_fit, pop_g0, pop_g1, fits, s_true, dcf)
    print("\nDone.")
