1. Input定义
对于一个batch_siz为n, embedding特征长度为m的input,记为A:
2. 计算距离矩阵
对A进行elementwise square, 得到:
然后对\(A^2\)按行做reduce sum
, 得到:
A乘上A的转置得:
由等式(3)(4)得:
等式(5)的计算结果即为我们要求的距离矩阵, 简单记为\(D\):
3. 计算Loss
为了简单说明计算过程,我们假设Input的batch_size=6,而且有三个类别,其中:
- 类别0: \(x_0, x_1\)
- 类别1: \(x_2\)
- 类别2: \(x_3, x_4, x_5\)
其距离矩阵如图1所示, 其中彩色填充部分为positive pair distances, 未填充部分为negtive pair distance:
假设我们选取\(x_0\)为anchor, \(x_1\)为positive, \(x_2\)为negtive, 则\(triplet (x_0, x_1, x_2)\)的loss为:
根据图1的第一行,我们可以计算出所有以\(x_0\)为anchor的triplet的loss.
首先选出所有与\(x_0\)同类的samples与\(x_0\)的距离, 记作向量:
然后,选出所有与\(x_0\)不同类的samples, 记作:
由等式(8)(9)得:
但是,从图1的第3行可以看出,\(D_2^{neg}=D_2[0:2] + D_2[3:6]= [d_{20}, d_{21}, d_{23}, d_{24}, d_{25}]\) 并不是连续的,
又:
由等式(11)(12)计算得到anchor为\(x_0\)的所有triplet的loss的和为:
同理,可计算出\(l(x_n)\), 最终的loss为:
4. Backward
TODO