ARPL (Adversarial Reciprocal Point Learning) (Chen et al., 2020a; 2021)
ARPL
データ、空間
DL: labeled samples ={(x1,y1),...,(xn,yn)}
DLk∈ Sk: positive training data, category k
DL≠k∈ Okpos: negative training data, category !=k
DU∈ Oneg: potential unknown data
Sk: deep embedding space of category k
Ok: open space of category k, Okpos ∪ Okneg:
Okpos: positive open space from other known classes
Okneg: negative open space as the remaining infinite unknown space
クラスKの全サンプルにおいて特徴量空間で表現される点からからPkまでの距離は、他のクラスや未知クラスにおけるどのサンプルのPkまでの距離よりも大きくなる。
max(ζ(DL≠k∪DU,Pk))≤d,∀d∈ζ(DkL,Pk)
d(C(x),Pk) =de(C(x),pki)−dd(C(x),Pk)
de(C(x),Pk) =1/m·‖C(x)−Pk‖22
dd(C(x),Pk) =C(x)·Pk
ζ(·,·): calculates the set of distances of all samples between two sets
Pk: reciprocal point of category k, learnable parameter
m: dimension
C: deep embedding function, embedding feature
実装
class Dist(nn.Module):
def __init__(self, num_classes=10, num_centers=1, feat_dim=2, init='random'):
super(Dist, self).__init__()
self.feat_dim = feat_dim
self.num_classes = num_classes
self.num_centers = num_centers
self.centers = nn.Parameter(torch.Tensor(num_classes * num_centers, self.feat_dim))
def forward(self, features, center=None, metric='l2'):
if metric == 'l2':
f_2 = torch.sum(torch.pow(features, 2), dim=1, keepdim=True)
if center is None:
c_2 = torch.sum(torch.pow(self.centers, 2), dim=1, keepdim=True)
dist = f_2 - 2*torch.matmul(features, torch.transpose(self.centers, 1, 0)) + torch.transpose(c_2, 1, 0)
else:
c_2 = torch.sum(torch.pow(center, 2), dim=1, keepdim=True)
dist = f_2 - 2*torch.matmul(features, torch.transpose(center, 1, 0)) + torch.transpose(c_2, 1, 0)
dist = dist / float(features.shape[1])
dist = torch.reshape(dist, [-1, self.num_classes, self.num_centers])
dist = torch.mean(dist, dim=2)
Class ARPLoss(nn.CrossEntropyLoss):
def __init__(self, **options):
self.Dist = Dist(num_classes=options['num_classes'], feat_dim=options['feat_dim'])
self.points = self.Dist.centers
self.radius = nn.Parameter(torch.Tensor(1))
self.radius.data.fill_(0)
self.margin_loss = nn.MarginRankingLoss(margin=1.0)
def forward(self, x, y, labels=None):
dist_dot_p = self.Dist(x, center=self.points, metric='dot')
dist_l2_p = self.Dist(x, center=self.points)
logits = dist_l2_p - dist_dot_p
loss = F.cross_entropy(logits / self.temp, labels)
center_batch = self.points[labels, :]
_dis_known = (x - center_batch).pow(2).mean(1)}
loss_r = self.margin_loss(self.radius, _dis_known, target)
loss = loss + self.weight_pl * loss_r
softmax function
p(y=k|x,C,P) = eγ**d(C(x),Pk) / ∑i=1N{eγ**d(C(x),Pi)}
γ: hyperparameter, default = 1.0
損失
Lc(x;θ,P) =−log p(y=k|x,C,P) : pytorch の CrossEntropyLoss 実装
Lo(x;θ,Pk,Rk) =max(de(C(x),Pk)−R,0)
L:joint loss L = Lc + λ·Lo.
Lc: Classifire loss
Lo: loss, open space risk
λ: weight of the adversarial open space riskmodule, default = 0.1
θ: parameters in the convolutional layers, learnable parameters
R: parameters in the losslayers, learnable margin, learnable parameters
ARPL+CS
敵対的画像の利用
CS:Confusing Samples as DU
The discriminator is optimized to discriminate the real and generated samples:
maxD{1/n∑i=1n[logD(xi) + log(1−D(G(zi)))]}
x=G(z)
D:X →[0,1] represents the probability of sample x being from the real distribution or a fakedistribution.
The generator is optimized by:
maxG{1/n∑i=1n[logD(G(zi)) +β·H(zi,P)], (18)=(16)+(17)
β: hyperparameter, default = 0.1
H(zi,P) =−1/N∑k=1N{S(zi,Pk)·log(S(zi,Pk))}: information entropy function
すべての逆数点の距離をバランスさせ、OGに近くなるようなサンプルを生成。
生成されたサンプルが既知のサンプルの境界から遠い場合、式(16)の損失大。
生成されたサンプルが既知のクラスに近い場合、式(17)の損失大。
The classifier C is optimized by the generated confusing samples as:
minC{1/n∑i=1n[L(xi,yi)−β·H(zi,P)], (19)
Note that the known samples and generated samplesare processed independently in Eq.(19).
ABNは、異なるドメインに属する特徴に対して別々のBNを保持することにより、混合分布を分離し、紛らわしいサンプルの負の影響を効果的にブロックする。
その他
図1:既知サンプルの特徴量と異なる未知サンプルの特徴量の重なりを減らすことで、認識を向上させる。アプリケーションでは、様々な未知サンプルと既知サンプルを最適に分離するスコア閾値を選択する必要がある。残念ながら、(b) Softmax (c) Prototype Learningのいずれにおいても、そのような閾値を見つけることは難しい。より良い分離は(d) ARPLで達成可能。
図2:ほとんどの手法は、「猫」に代表的な特徴を学習することに焦点を当てている。これに対して、「猫ではない」の潜在的な特徴を利用して「猫」を識別する。⇒ Reciprocal Point の利用。PK は潜在的な未利用のクラス外空間のインスタンス化された表現。OSRにおいて「ネコとは何か」という問題を解く際の不確実性を減らすために利用する。
図3: 開集合認識のための逆説的互変異性点学習(ARPL)提案手法の概要。(a)単一クラスに対する逆相対点学習は、既知の各クラスをその逆相対点から遠ざける。(b)多クラスAdversarialFusionは、相互点によって構成される多カテゴリ境界空間間の対立を引き起こす。その結果、既知クラスは特徴空間の周辺に押しやられ、未知クラスは境界空間内に制限される. (c)Instantiated Adversarial Enhancementは、分類器の信頼性を高めるために、より有効でより多様な紛らわしいサンプルを生成する。