2026年1月12日月曜日

Synthetic Control

Synthetic Control は、ベイズ統計を用いた因果推論です。機械学習ではなく、MCMCです。


1. モデル設定(合成コントロールの基本形

ある「介入クラス」A について、時点 t の平均テスト点数を y_A,t とする。また、介入を受けていない対照クラスを j = 1,…,J とし、それぞれの平均点を y_j,t と書く。
合成コントロールの考え方は「A クラスは、対照クラスの“重み付き平均”で近似できる」とみなすことである。そこで、時点 t における A クラスの期待される点数(=介入がなかったときの基準)を

μ_t = w_1 y_1,t + … + w_J y_J,t = Σ_j w_j y_j,t

と表す。w_j は 0 以上で和が 1 になるような重みとする。
実際の点数 y_A,t は、体調・偶然・測定誤差などにより μ_t からずれる。このズレを「平均 0、標準偏差 σ の正規分布に従う誤差」と仮定して

y_A,t = μ_t + 誤差_t
誤差_t ~ Normal(0, σ)

とおく。確率変数として書けば

y_A,t|w, σ ~ Normal(μ_t, σ)

である。


2. ベイズの定理と記号

推定したいパラメータを

θ = (w_1,…,w_J, σ)

とまとめて書く。ベイズの定理は以下となる。

p(θ|データ) = p(データ|θ) p(θ) / p(データ)

p(θ|データ):事後分布(データを見た後に、θ がどの程度になりそうか)
p(データ|θ):尤度(θ が真なら、このデータはどれくらい出やすいか)
p(θ):事前分布(データを見る前に、θ がどのような値を取りやすいと考えるか)
p(データ):正規化定数(全部の θ について積分した値)

実際の計算では p(データ) を明示的に計算する必要はなく、

p(θ|データ) ∝ p(データ|θ) p(θ)

という比例関係として扱うことが多い。


3. 事前分布 p(θ)

重みベクトル w = (w_1,…,w_J) について

w ~ Dirichlet(α_1,…,α_J)

とおく。誤差のばらつき σ については

σ ~ HalfNormal(τ)

とする。これらをまとめると、事前分布は

p(θ) = p(w_1,…,w_J, σ)
= Dirichlet(w | α_1,…,α_J) × HalfNormal(σ | τ)

となる。ここでは w と σ を互いに独立と仮定している。


4. 尤度 p(データ|θ)

合成コントロールの重み w は、「介入前」で学習する。介入前の時点を t = 1,…,T0 とし、A クラスの平均点:y_A,t、各対照クラスの平均点:y_j,t が観測されているとする。モデルでは、各 t について

y_A,t|θ ~ Normal(μ_t, σ)
μ_t = Σ_j w_j y_j,t

であるから、ある θ が与えられたとき、「その θ のもとで、実際に観測された系列 {y_A,1,…,y_A,T0} がどれくらい出やすいか」を表すのが尤度である。
時点 t ごとに独立と仮定すると、介入前データ全体に対する尤度は

p(データ|θ) = Π_{t=1}^{T0} p(y_A,t|θ)
= Π_{t=1}^{T0} Normal(y_A,t ; μ_t, σ)

となる。ここで μ_t は w と対照クラスの観測値 {y_j,t} から一意に決まる。
・μ_t に近い y_A,t が多ければ、この θ の尤度は大きくなる
・μ_t から遠く離れた y_A,t が多ければ、尤度は小さくなる
という直感と対応している。

5. 事後分布 p(θ|データ)

ベイズの定理より、介入前データを見た後の θ の分布(事後分布)は

p(θ | データ) ∝ p(データ | θ) p(θ)
= {Π_{t=1}^{T0} Normal(y_A,t ; μ_t, σ)}
 × Dirichlet(w | α_1,…,α_J) × HalfNormal(σ | τ)

で与えられる。これは介入前の A クラスの点数系列をどれだけよく再現するか(尤度)、もともと重み w や σ をどう想定していたか(事前)の両方をバランスさせた結果である。
この事後分布を式のまま解析的に求めることは難しいため、MCMC(たとえば NUTS というアルゴリズム)を用いて

θ^(1), θ^(2), …, θ^(S) = (w_1^(s),…,w_J^(s), σ^(s)), s=1,…,S

という多数のサンプルを生成し、p(θ|データ) を近似する。それぞれの θ^(s) が「あり得る合成コントロールの重みと誤差の大きさ」を表しており、全体として「どのような w がどの程度ありそうか」を数値的に把握できる。


6. Synthetic Control による介入効果推定

6.1 介入「前」のデータから事後分布 p(θ|データ) を推定
上で述べたように、介入前 t=1,…,T0 のデータを用いて事後分布 p(θ | データ) を求める。

6.2 介入「後」のシンセティック点数 μ_t^(s) を計算
介入後 t = T0+1,…,T の期間に注目する。この期間についても、各対照クラス j の平均点 y_j,t が観測されているとする。それぞれの事後サンプル θ^(s) に対して、時点 t ごとに

μ_post,t^(s) = Σ_j w_j^(s) y_j,t

を計算する。これは「サンプル s に対応する合成コントロールの重みを用いて、介入後の A クラスの“カウンターファクチュアル(介入がなかった場合の想定点数)”を計算したもの」である。したがって、{μ_post,t^(s)}_{s=1,…,S} は「介入がなかったとしたとき、時点 t における A クラスの点数はどのくらいになり得るか」を表す事後予測分布となる。

6.3 介入効果の事後分布を得る
実際に観測された介入後の A クラスの平均点を y_A,t(t > T0)とすると、サンプル s における時点 t の介入効果は

effect_t^(s) = y_A,t − μ_post,t^(s)

と定義できる。これは
・正なら「実際の点数の方が高い(介入がプラス効果)」
・負なら「実際の点数の方が低い(介入がマイナス効果)」
ことを意味する。

時点 t ごとに {effect_t^(s)} の平均
 → その時点の平均介入効果 2.5%点と 97.5%点
 → 95% 信用区間(事後分布の幅)を
求めれば、「いつ」「どれくらい」介入効果が表れたかをベイズ的に評価できる。
さらに、介入後の全期間について effect_t^(s) を平均すれば
・介入後期間全体の平均効果
・その信頼区間(信用区間)
も同じ要領で推定できる。

このようにして、従来の Synthetic Control(複数の対照ユニットの重み付き平均)を「事前分布」「尤度」「事後分布」の枠組みの中にきちんと位置づけ、ベイズ的な不確実性評価と一体化した形で介入効果を推定する。


7. Pythonコード(AI提案)

import pandas as pd
import numpy as np
import pymc as pm
import arviz as az

# --- データ準備 ---
# yA:  shape=(T,)       介入クラスA の時系列
# Y_donors: shape=(T,J) 対照クラス群の時系列
# T0: 介入前の最後の時点(0始まりのインデックス)

T, J = Y_donors.shape
yA_pre       = yA[: T0+1]            # 介入前データ
Y_donor_pre  = Y_donors[: T0+1, :]
yA_post      = yA[T0+1 :]            # 介入後データ
Y_donor_post = Y_donors[T0+1 :, :]

# --- 6.1 介入前データでモデルフィッティング ---
with pm.Model() as sc_model:
    # 事前分布
    w     = pm.Dirichlet("w",     a=np.ones(J))   # w_j ≥ 0, Σw_j=1
    sigma = pm.HalfNormal("sigma", sigma=1.0)     # σ>0

    # 合成コントロールの期待値 μ_t
    mu_lin = pm.math.dot(Y_donor_pre, w)
    mu_pre = pm.Deterministic("mu_pre", mu_lin)

    # 尤度
    pm.Normal("y_obs", mu=mu_pre, sigma=sigma,
              observed=yA_pre)

    # サンプリング
    trace = pm.sample(
        draws=1000,          # 事後サンプル 1000
        tune=1000,           # チューニング 1000
        chains=4,            # チェーン数
        cores=4,             # 並列数
        target_accept=0.9,   # NUTS の受諾率
        max_treedepth=12,
        random_seed=0,
        progressbar=True
    )

# --- 6.2 介入後のカウンターファクチュアルを計算 ---
# ポスターサンプルから w を (S,J) にフラット化
w_samples = trace.posterior["w"] \
    .stack(sample=("chain", "draw")) \
    .values  # shape = (S, J)

# μ_post[s,t] = Σ_j w_samples[s,j] * Y_donor_post[t,j]
# 結果 shape = (S, T_post)
mu_post = (Y_donor_post @ w_samples.T).T

# --- 6.3 介入効果を計算 ---
# effect[s,t] = 実測 yA_post[t] − μ_post[s,t]
effect = yA_post[None, :] - mu_post  # shape = (S, T_post)

# 時点別平均効果と95%信用区間
effect_mean_t = effect.mean(axis=0)                # (T_post,)
effect_hdi_t  = az.hdi(effect, hdi_prob=0.95)      # (T_post, 2)

# 介入後期間全体の平均効果とその信用区間
avg_effect     = effect.mean()                     # scalar
avg_effect_hdi = az.hdi(effect.flatten(), hdi_prob=0.95)  # (2,)

# --- 出力例 ---
print("時点別平均効果    :", effect_mean_t)
print("時点別95%信用区間:", effect_hdi_t)
print("全期間平均効果    :", avg_effect)
print("全期間95%信用区間:", avg_effect_hdi)


0 件のコメント:

コメントを投稿