PyTorch 中的 CrossEntropyLoss 及其具体应用

交叉熵

交叉熵 (Cross Entropy) 是 Shannon 信息论中的一个重要概念,主要用于度量两个概率分布间的差异性信息。

给定两个概率分布 $p$ 和 $q$,$p$ 相对于 $q$ 的交叉熵定义为:
$$
H(p,q)=E_p[-\log q]=H(p)+D_{KL}(p||q),
$$
其中 $H(p)$ 是 $p$ 的熵,$D_{KL}(p||q)$ 是从 $p$ 与 $q$ 的KL散度 (也被称为 $p$ 相对于 $q$ 的相对熵)。

熵的计算公式为:
$$
H(X)=E[I(X)]=E[-\ln P(X)]=-\sum P(x_i)\log P(x_i)
$$
KL散度的计算公式为:
$$
D_{KL}(p||q)=\sum P(x_i)\log \frac{P(x_i)}{Q(x_i)}
$$
对于离散分布的 $p$ 和 $q$,也就是:
$$
\begin{aligned}
H(p,q)&=\sum P(x_i)\log \frac{P(x_i)}{Q(x_i)}-\sum P(x_i)\log P(x_i)\
&=\sum P(x_i)\log \frac{1}{Q(x_i)}\
&=-\sum\limits_{x}p(x)\log q(x)
\end{aligned}
$$

torch.nn.CrossEntropyLoss

PyTorch 里的 CrossEntropyLoss 通过结合 torch.nn.LogSoftmaxtorch.nn.NLLLoss 实现,这里不赘述这两个类的实现原理,仅描述 toch.nn.CrossEntropyLoss 的具体内容。

方法具体为:torch.nn.CrossEntropyLoss(input, target)

参数:

  • weight (Tensor, optional) 给每个 class 的权重,Shape 为 $C$。

Shape:

  • Input:$(N,C)$ 或者 $(N,C,d_1,d_2,\dots,d_K)$ ,其中 $C$ 是 class 的数量,$N$ 是样本数量。
  • Target:$(N)$,其中的数值范围在 $0\leq\text{targets}[i]\leq C-1$ 用来指示类,或者 $(N,d_1,d_2,\dots,d_K)$。
  • Output:和 Target 的 Size 相同。

具体公式为:
$$
\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
= -x[class] + \log\left(\sum_j \exp(x[j])\right)
$$
这个公式的含义为:先 LogSoftmax 求得样本 $x$ 是否属于它的真实类别 $class$,如果是,则值接近1,如果不是,则值接近0 (简单的 Softmax 思想),然后所有样本的值加和或者加和求平均,求得最终的 CrossEntropyLoss。

具体应用

CrossEntropyLoss 可以被应用在很多方面,这里以 SimCLR 中的计算为例,会介绍一些代码实现的 trick。

首先 SimCLR 中的正样本对 $(i,j)$ 间的 loss 被定义为:
$$
\ell_{i,j}=-\log\frac{\exp(sim(z_i,z_j)/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k\neq i]}\exp(sim(z_i,z_k)/\tau)},
$$
其中 $2N$ 为样本数量,有 $N$ 个原样本,$N$ 个通过数据增广得到的新样本,对于一个样本来说,有一个它增广得到的样本,正样本数量为2,其他 $2N-2$ 个样本都是负样本。$sim(\cdot)$ 函数为相似性函数,$\tau$ 为比例参数 temperature。

总 Loss function 为:
$$
\mathcal{L}=\frac{1}{2N}\sum_{k=1}^N[\ell(2k-1,2k)+\ell(2k,2k-1)]
$$
下面来看一下代码中的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# features shape:(minibatch*2, dims), (i, dims) 和 (i+minibatch, dims) 互为正样本对
def info_nce_loss(features, **kwargs):

labels = torch.cat([torch.arange(kwargs['batch_size']) for i in range(2)], dim=0)
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels = labels.to(kwargs['device'])
# labels 为四个区块均为单位阵的矩阵, 以 4x4 矩阵为例, labels:
# [[1, 0, 1, 0],
# [0, 1, 0, 1],
# [1, 0, 1, 0],
# [0, 1, 0, 1]]
# 最终 labels shape: (minibatch*2, minibatch*2)

features = F.normalize(features, dim=1) # 把 每一行 normalize

similarity_matrix = torch.matmul(features, features.T) # 计算相似度矩阵

mask = torch.eye(labels.shape[0], dtype=torch.bool).to(kwargs['device'])
# 此处 mask 是和 labels 相同 shape 的单位阵
labels = labels[~mask].view(labels.shape[0], -1)
# 把 labels 对角线上的值 mask 掉, 将矩阵 reshape 成 (minibatch*2, minibatch*2-1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
# 对相似度矩阵同样操作, reshape 成 (minibatch*2, minibatch*2-1)
# 这里几步操作就完成了 l(i, j) 分母上剔除 k=i 情况的操作

positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
# 正样本的 sim
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
# 负样本的 sim

logits = torch.cat([positives, negatives], dim=1)
# 每个样本的第一列元素即为正样本, 要使正样本的 sim 最大
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(kwargs['device'])

logits = logits / kwargs['temperature']
# 这里除以公式里的 \tau
return logits, labels
# 返回 logits 和 labels, 直接调用 CrossEntropyLoss 即可

torch.nn.CrossEntropyLoss(logits, labels)
# logits shape:(minibatch*2,minibatch*2-1) 每一行是和其他样本的相似度,每行的第一列是和正样本的相似度
# labels 因为正样本都在第一列, 所以 labels 是全0 Tensor

代码中注释详细叙述了流程,总体的 Trick 概括如下:

  • 通过 mask 剔除分母中 $k=i$ 的部分
  • 把所有正样本的 similarity reshape 到第一列
  • 最后所有 labels 为0,直接调用 CrossEntropyLoss

PyTorch 中的 CrossEntropyLoss 及其具体应用
https://pandintelli.github.io/2022/01/20/CrossEntropyLoss/
作者
Pand
发布于
2022年1月20日
许可协议