【metric learning】Dense triplet loss

1. Input定义

对于一个batch_siz为n, embedding特征长度为m的input,记为A:

2. 计算距离矩阵

对A进行elementwise square, 得到:

然后对\(A^2\)按行做reduce sum, 得到:

A乘上A的转置得:

由等式(3)(4)得:

等式(5)的计算结果即为我们要求的距离矩阵, 简单记为\(D\):

3. 计算Loss

为了简单说明计算过程,我们假设Input的batch_size=6,而且有三个类别,其中:

  • 类别0: \(x_0, x_1\)
  • 类别1: \(x_2\)
  • 类别2: \(x_3, x_4, x_5\)

其距离矩阵如图1所示, 其中彩色填充部分为positive pair distances, 未填充部分为negtive pair distance:

图片

图1

假设我们选取\(x_0\)为anchor, \(x_1\)为positive, \(x_2\)为negtive, 则\(triplet (x_0, x_1, x_2)\)的loss为:

根据图1的第一行,我们可以计算出所有以\(x_0\)为anchor的triplet的loss.

首先选出所有与\(x_0\)同类的samples与\(x_0\)的距离, 记作向量:

然后,选出所有与\(x_0\)不同类的samples, 记作:

由等式(8)(9)得:

但是,从图1的第3行可以看出,\(D_2^{neg}=D_2[0:2] + D_2[3:6]= [d_{20}, d_{21}, d_{23}, d_{24}, d_{25}]\) 并不是连续的,

又:

由等式(11)(12)计算得到anchor为\(x_0\)的所有triplet的loss的和为:

同理,可计算出\(l(x_n)\), 最终的loss为:

4. Backward

TODO

更早的文章

【数学基础】直观理解泰勒展开

1. 问题为了更好的理解分析这个世界,我们习惯把现实问题建模为数学问题。比如我们把简谐运动的某个特性模拟为正玄/余玄函数:图1-1然而,现实是非常复杂的,能用函数模拟的事物是有限的,即使能用函数模拟,大部分情况也是非常复杂的函数,当前人类的数学能力不能很完美的去处理这些复杂的数学公式。2. 局部模拟面对复杂的函数,硬上是肯定不行的,俗话说“退一步海阔天空”,于是我们想:我们可以只研究我们关心的函数特性,或者这我们只关注函数的一部分。假设,对于余玄函数:我们只关心其在(x=0)附近的性质,我...…

继续阅读