Synthetic Control は、ベイズ統計を用いた因果推論です。機械学習ではなく、MCMCです。
1. モデル設定(合成コントロールの基本形)
ある「介入クラス」A について、時点 t の平均テスト点数を y_A,t とする。また、介入を受けていない対照クラスを j = 1,…,J とし、それぞれの平均点を y_j,t と書く。
合成コントロールの考え方は「A クラスは、対照クラスの“重み付き平均”で近似できる」とみなすことである。そこで、時点 t における A クラスの期待される点数(=介入がなかったときの基準)を
μ_t = w_1 y_1,t + … + w_J y_J,t = Σ_j w_j y_j,t
と表す。w_j は 0 以上で和が 1 になるような重みとする。
実際の点数 y_A,t は、体調・偶然・測定誤差などにより μ_t からずれる。このズレを「平均 0、標準偏差 σ の正規分布に従う誤差」と仮定して
y_A,t = μ_t + 誤差_t
誤差_t ~ Normal(0, σ)
とおく。確率変数として書けば
y_A,t|w, σ ~ Normal(μ_t, σ)
である。
2. ベイズの定理と記号
推定したいパラメータを
θ = (w_1,…,w_J, σ)
とまとめて書く。ベイズの定理は以下となる。
p(θ|データ) = p(データ|θ) p(θ) / p(データ)
p(θ|データ):事後分布(データを見た後に、θ がどの程度になりそうか)
p(データ|θ):尤度(θ が真なら、このデータはどれくらい出やすいか)
p(θ):事前分布(データを見る前に、θ がどのような値を取りやすいと考えるか)
p(データ):正規化定数(全部の θ について積分した値)
実際の計算では p(データ) を明示的に計算する必要はなく、
p(θ|データ) ∝ p(データ|θ) p(θ)
という比例関係として扱うことが多い。
3. 事前分布 p(θ)
重みベクトル w = (w_1,…,w_J) について
w ~ Dirichlet(α_1,…,α_J)
とおく。誤差のばらつき σ については
σ ~ HalfNormal(τ)
とする。これらをまとめると、事前分布は
p(θ) = p(w_1,…,w_J, σ)
= Dirichlet(w | α_1,…,α_J) × HalfNormal(σ | τ)
となる。ここでは w と σ を互いに独立と仮定している。
4. 尤度 p(データ|θ)
合成コントロールの重み w は、「介入前」で学習する。介入前の時点を t = 1,…,T0 とし、A クラスの平均点:y_A,t、各対照クラスの平均点:y_j,t が観測されているとする。モデルでは、各 t について
y_A,t|θ ~ Normal(μ_t, σ)
μ_t = Σ_j w_j y_j,t
であるから、ある θ が与えられたとき、「その θ のもとで、実際に観測された系列 {y_A,1,…,y_A,T0} がどれくらい出やすいか」を表すのが尤度である。
時点 t ごとに独立と仮定すると、介入前データ全体に対する尤度は
p(データ|θ) = Π_{t=1}^{T0} p(y_A,t|θ)
= Π_{t=1}^{T0} Normal(y_A,t ; μ_t, σ)
となる。ここで μ_t は w と対照クラスの観測値 {y_j,t} から一意に決まる。
・μ_t に近い y_A,t が多ければ、この θ の尤度は大きくなる
・μ_t から遠く離れた y_A,t が多ければ、尤度は小さくなる
という直感と対応している。
5. 事後分布 p(θ|データ)
ベイズの定理より、介入前データを見た後の θ の分布(事後分布)は
p(θ | データ) ∝ p(データ | θ) p(θ)
= {Π_{t=1}^{T0} Normal(y_A,t ; μ_t, σ)}
× Dirichlet(w | α_1,…,α_J) × HalfNormal(σ | τ)
で与えられる。これは介入前の A クラスの点数系列をどれだけよく再現するか(尤度)、もともと重み w や σ をどう想定していたか(事前)の両方をバランスさせた結果である。
この事後分布を式のまま解析的に求めることは難しいため、MCMC(たとえば NUTS というアルゴリズム)を用いて
θ^(1), θ^(2), …, θ^(S) = (w_1^(s),…,w_J^(s), σ^(s)), s=1,…,S
という多数のサンプルを生成し、p(θ|データ) を近似する。それぞれの θ^(s) が「あり得る合成コントロールの重みと誤差の大きさ」を表しており、全体として「どのような w がどの程度ありそうか」を数値的に把握できる。
6. Synthetic Control による介入効果推定
6.1 介入「前」のデータから事後分布 p(θ|データ) を推定
上で述べたように、介入前 t=1,…,T0 のデータを用いて事後分布 p(θ | データ) を求める。
6.2 介入「後」のシンセティック点数 μ_t^(s) を計算
介入後 t = T0+1,…,T の期間に注目する。この期間についても、各対照クラス j の平均点 y_j,t が観測されているとする。それぞれの事後サンプル θ^(s) に対して、時点 t ごとに
μ_post,t^(s) = Σ_j w_j^(s) y_j,t
を計算する。これは「サンプル s に対応する合成コントロールの重みを用いて、介入後の A クラスの“カウンターファクチュアル(介入がなかった場合の想定点数)”を計算したもの」である。したがって、{μ_post,t^(s)}_{s=1,…,S} は「介入がなかったとしたとき、時点 t における A クラスの点数はどのくらいになり得るか」を表す事後予測分布となる。
6.3 介入効果の事後分布を得る
実際に観測された介入後の A クラスの平均点を y_A,t(t > T0)とすると、サンプル s における時点 t の介入効果は
effect_t^(s) = y_A,t − μ_post,t^(s)
と定義できる。これは
・正なら「実際の点数の方が高い(介入がプラス効果)」
・負なら「実際の点数の方が低い(介入がマイナス効果)」
ことを意味する。
時点 t ごとに {effect_t^(s)} の平均
→ その時点の平均介入効果 2.5%点と 97.5%点
→ 95% 信用区間(事後分布の幅)を
求めれば、「いつ」「どれくらい」介入効果が表れたかをベイズ的に評価できる。
さらに、介入後の全期間について effect_t^(s) を平均すれば
・介入後期間全体の平均効果
・その信頼区間(信用区間)
も同じ要領で推定できる。
このようにして、従来の Synthetic Control(複数の対照ユニットの重み付き平均)を「事前分布」「尤度」「事後分布」の枠組みの中にきちんと位置づけ、ベイズ的な不確実性評価と一体化した形で介入効果を推定する。
7. Pythonコード(AI提案)
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
# --- データ準備 ---
# yA: shape=(T,) 介入クラスA の時系列
# Y_donors: shape=(T,J) 対照クラス群の時系列
# T0: 介入前の最後の時点(0始まりのインデックス)
T, J = Y_donors.shape
yA_pre = yA[: T0+1] # 介入前データ
Y_donor_pre = Y_donors[: T0+1, :]
yA_post = yA[T0+1 :] # 介入後データ
Y_donor_post = Y_donors[T0+1 :, :]
# --- 6.1 介入前データでモデルフィッティング ---
with pm.Model() as sc_model:
# 事前分布
w = pm.Dirichlet("w", a=np.ones(J)) # w_j ≥ 0, Σw_j=1
sigma = pm.HalfNormal("sigma", sigma=1.0) # σ>0
# 合成コントロールの期待値 μ_t
mu_lin = pm.math.dot(Y_donor_pre, w)
mu_pre = pm.Deterministic("mu_pre", mu_lin)
# 尤度
pm.Normal("y_obs", mu=mu_pre, sigma=sigma,
observed=yA_pre)
# サンプリング
trace = pm.sample(
draws=1000, # 事後サンプル 1000
tune=1000, # チューニング 1000
chains=4, # チェーン数
cores=4, # 並列数
target_accept=0.9, # NUTS の受諾率
max_treedepth=12,
random_seed=0,
progressbar=True
)
# --- 6.2 介入後のカウンターファクチュアルを計算 ---
# ポスターサンプルから w を (S,J) にフラット化
w_samples = trace.posterior["w"] \
.stack(sample=("chain", "draw")) \
.values # shape = (S, J)
# μ_post[s,t] = Σ_j w_samples[s,j] * Y_donor_post[t,j]
# 結果 shape = (S, T_post)
mu_post = (Y_donor_post @ w_samples.T).T
# --- 6.3 介入効果を計算 ---
# effect[s,t] = 実測 yA_post[t] − μ_post[s,t]
effect = yA_post[None, :] - mu_post # shape = (S, T_post)
# 時点別平均効果と95%信用区間
effect_mean_t = effect.mean(axis=0) # (T_post,)
effect_hdi_t = az.hdi(effect, hdi_prob=0.95) # (T_post, 2)
# 介入後期間全体の平均効果とその信用区間
avg_effect = effect.mean() # scalar
avg_effect_hdi = az.hdi(effect.flatten(), hdi_prob=0.95) # (2,)
# --- 出力例 ---
print("時点別平均効果 :", effect_mean_t)
print("時点別95%信用区間:", effect_hdi_t)
print("全期間平均効果 :", avg_effect)
print("全期間95%信用区間:", avg_effect_hdi)
0 件のコメント:
コメントを投稿