做分类任务时,经常会遇到CrossEntropyLoss。
比如图片分类、文本分类、情感分析,只要模型要从多个类别里选一个答案,它就很可能会出现。
它为什么这么常用?
因为分类任务里,我们关心的不只是“猜没猜对”,还关心模型对正确答案有多自信。
分类模型输出的是什么
假设我们有一个三分类任务:猫、狗、鸟。
模型最后可能会输出三个分数:
猫:2.1 狗:0.3 鸟:-1.2
这些分数还不是概率,它们通常叫 logits。
我们可以用Softmax把它们转换成概率:
猫:0.82 狗:0.15 鸟:0.03
如果真实答案是猫,那么这个预测就比较好,因为模型给猫的概率最高。
CrossEntropyLoss 关注什么
CrossEntropyLoss 会重点看正确类别的概率。
如果正确类别的概率很高,loss 就小;如果正确类别的概率很低,loss 就大。
比如真实答案是猫:
模型给猫 0.9,loss 很小
模型给猫 0.5,loss 变大
模型给猫 0.1,loss 会很大
这很符合直觉。
模型不仅要选对,还要对正确答案有足够信心。
为什么不能只看准确率
假设两个模型都预测对了。
模型 A 给正确类别的概率是0.51。
模型 B 给正确类别的概率是0.95。
从准确率看,它们都对了;但从训练角度看,模型 B 显然更好。
CrossEntropyLoss 就能区分这种差别。
它会鼓励模型把更多概率分给正确类别,而不是只要勉强猜对就行。
PyTorch 里怎么用
在 PyTorch 里,CrossEntropyLoss通常这样用:
import torch from torch import nn loss_fn = nn.CrossEntropyLoss() logits = torch.tensor([[2.1, 0.3, -1.2]]) target = torch.tensor([0]) loss = loss_fn(logits, target)
这里有一个很重要的点:
CrossEntropyLoss接收的是 logits,不需要你提前手动做 Softmax。
因为 PyTorch 的CrossEntropyLoss内部已经包含了LogSoftmax和负对数似然损失。
如果你先手动 Softmax,再传进去,反而可能造成数值问题。
标签应该长什么样
对于多分类任务,target 通常是类别索引。
比如:
0 表示猫 1 表示狗 2 表示鸟
如果一批数据有 4 个样本,标签可能是:
target = torch.tensor([0, 2, 1, 0])
这表示第 1 个和第 4 个样本是猫,第 2 个是鸟,第 3 个是狗。
小结
CrossEntropyLoss 很适合分类任务,因为它能衡量模型对正确类别的信心。
记住三个关键点:
它常用于分类任务。
PyTorch 里通常直接传 logits。
target 通常是类别索引,而不是 one-hot。
理解了这些,再看分类模型训练代码,就不会被 loss 那一行卡住。
技术图:把关键链路画清楚
可运行实验:拆开 CrossEntropyLoss 的两步计算
PyTorch 的交叉熵等价于先对 logits 做log_softmax,再取正确类别的负对数似然。下面直接验证二者结果一致。
import torch from torch import nn logits = torch.tensor([[2.0, 1.0, 0.1]]) target = torch.tensor([0]) ce = nn.CrossEntropyLoss()(logits, target) log_probs = torch.log_softmax(logits, dim=1) manual = -log_probs[0, target[0]] print("概率:", [round(v, 4) for v in torch.softmax(logits, dim=1)[0].tolist()]) print(f"CrossEntropyLoss: {ce.item():.6f}") print(f"手动计算: {manual.item():.6f}")运行结果:
概率: [0.659, 0.2424, 0.0986] CrossEntropyLoss: 0.417030 手动计算: 0.417030
正确类别概率约为 0.659,对它取负对数得到约 0.417。正确类别概率越低,损失越大。
常见误区
先 Softmax 再传给
CrossEntropyLoss。它内部已经包含相应计算,应直接传 logits。Target 应该是 one-hot。常规多分类场景下 target 是
LongTensor类别索引。
动手练习
保持正确类别不变,把第一个 logit 从 2 改成 4,验证正确类别概率上升、损失下降。
本文首发于「去你想去的地方」: CrossEntropyLoss 详解:分类任务为什么常用它 | 去你想去的地方
完整学习路线、视频版和后续更新请访问原文。