PyTorch 中的 CrossEntropyLoss 及其具体应用
交叉熵
交叉熵 (Cross Entropy) 是 Shannon 信息论中的一个重要概念,主要用于度量两个概率分布间的差异性信息。
给定两个概率分布 $p$ 和 $q$,$p$ 相对于 $q$ 的交叉熵定义为:
$$
H(p,q)=E_p[-\log q]=H(p)+D_{KL}(p||q),
$$
其中 $H(p)$ 是 $p$ 的熵,$D_{KL}(p||q)$ 是从 $p$ 与 $q$ 的KL散度 (也被称为 $p$ 相对于 $q$ 的相对熵)。
熵的计算公式为:
$$
H(X)=E[I(X)]=E[-\ln P(X)]=-\sum P(x_i)\log P(x_i)
$$
KL散度的计算公式为:
$$
D_{KL}(p||q)=\sum P(x_i)\log \frac{P(x_i)}{Q(x_i)}
$$
对于离散分布的 $p$ 和 $q$,也就是:
$$
\begin{aligned}
H(p,q)&=\sum P(x_i)\log \frac{P(x_i)}{Q(x_i)}-\sum P(x_i)\log P(x_i)\
&=\sum P(x_i)\log \frac{1}{Q(x_i)}\
&=-\sum\limits_{x}p(x)\log q(x)
\end{aligned}
$$
torch.nn.CrossEntropyLoss
PyTorch 里的 CrossEntropyLoss 通过结合 torch.nn.LogSoftmax
和 torch.nn.NLLLoss
实现,这里不赘述这两个类的实现原理,仅描述 toch.nn.CrossEntropyLoss
的具体内容。
方法具体为:torch.nn.CrossEntropyLoss(input, target)
。
参数:
- weight (Tensor, optional) 给每个 class 的权重,Shape 为 $C$。
Shape:
- Input:$(N,C)$ 或者 $(N,C,d_1,d_2,\dots,d_K)$ ,其中 $C$ 是 class 的数量,$N$ 是样本数量。
- Target:$(N)$,其中的数值范围在 $0\leq\text{targets}[i]\leq C-1$ 用来指示类,或者 $(N,d_1,d_2,\dots,d_K)$。
- Output:和 Target 的 Size 相同。
具体公式为:
$$
\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
= -x[class] + \log\left(\sum_j \exp(x[j])\right)
$$
这个公式的含义为:先 LogSoftmax
求得样本 $x$ 是否属于它的真实类别 $class$,如果是,则值接近1,如果不是,则值接近0 (简单的 Softmax 思想),然后所有样本的值加和或者加和求平均,求得最终的 CrossEntropyLoss。
具体应用
CrossEntropyLoss 可以被应用在很多方面,这里以 SimCLR 中的计算为例,会介绍一些代码实现的 trick。
首先 SimCLR 中的正样本对 $(i,j)$ 间的 loss 被定义为:
$$
\ell_{i,j}=-\log\frac{\exp(sim(z_i,z_j)/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(sim(z_i,z_k)/\tau)},
$$
其中 $2N$ 为样本数量,有 $N$ 个原样本,$N$ 个通过数据增广得到的新样本,对于一个样本来说,有一个它增广得到的样本,正样本数量为2,其他 $2N-2$ 个样本都是负样本。$sim(\cdot)$ 函数为相似性函数,$\tau$ 为比例参数 temperature。
总 Loss function 为:
$$
\mathcal{L}=\frac{1}{2N}\sum_{k=1}^N[\ell(2k-1,2k)+\ell(2k,2k-1)]
$$
下面来看一下代码中的实现:
1 |
|
代码中注释详细叙述了流程,总体的 Trick 概括如下:
- 通过 mask 剔除分母中 $k=i$ 的部分
- 把所有正样本的 similarity reshape 到第一列
- 最后所有 labels 为0,直接调用 CrossEntropyLoss