CAC损失函数实战:用PyTorch实现开集识别的类锚点聚类
在工业质检、自动驾驶或者安防监控这类实际场景里,我们训练好的模型总会遇到一些“没见过”的东西。比如,一个训练来识别几种特定划痕的缺陷检测模型,产线上突然出现了一种全新的、从未标注过的裂纹。传统的分类模型,比如用交叉熵损失训练的网络,很可能会自信地把这种新裂纹错误地归入某个已知的划痕类别,导致漏检,这在安全要求极高的工业环境中是致命的。这就是开集识别要解决的核心问题:模型不仅要能正确分类已知类别,还得有能力识别并拒绝那些训练时从未见过的“未知”类别。
现有的很多开集方法,都是在模型训练好之后,再在网络的logit空间里计算距离来判断“已知”和“未知”。这背后有个假设:已知类别的样本在logit空间里会聚集成紧密的簇,而未知样本则会离这些簇比较远。但问题在于,用交叉熵损失训练出来的网络,它的logit空间可不一定这么“听话”。交叉熵只关心把概率分对,并不保证样本在特征空间里能形成紧凑、可分簇。这就好比只教学生记住每个题目的答案,却不教他们理解题目之间的本质区别,一旦遇到没讲过的新题型,就很容易套用旧答案,导致错误。
Class Anchor Clustering 损失函数就是为了从根本上解决这个问题而生的。它不再依赖事后的距离度量,而是在训练过程中就“手把手”地指导网络:让每个已知类别的样本,都紧紧地围绕着一个预先设定好的“锚点”聚集起来,同时让不同类别的锚点之间保持足够的距离。这种“类内紧,类间疏”的特性,正是开集识别梦寐以求的。今天,我们就抛开理论推导,直接深入到代码层面,看看如何在PyTorch里亲手实现CAC损失,并把它应用到小样本缺陷检测这类极具挑战性的工业场景中。
1. 理解CAC损失的核心:锚点、距离与双重约束
在动手写代码之前,我们得先搞清楚CAC损失到底在做什么。它不是一个凭空冒出来的复杂公式,而是由两个直观的几何约束巧妙组合而成。
想象一下,我们在一个高维的logit空间里(比如网络最后的全连接层输出),为每一个已知类别预先放置一个固定的坐标点,这就是锚点。一个很聪明且简单的设置是,让第 i 个类别的锚点就在第 i 个坐标轴上,距离原点为 α。用向量表示就是 c_i = α * e_i,其中 e_i 是一个只在第 i 维为1的one-hot向量。这样做的好处是,不同类别的锚点天生就是正交的,初始距离就是 √2 * α,为类间分离提供了一个很好的起点。
CAC损失由两部分组成,它们分别从不同角度约束样本在空间中的位置:
- Anchor Loss:类内紧凑性约束。它的目标非常直接:让一个样本的特征向量尽可能靠近它所属类别的锚点。计算方式就是样本logit向量
z与其真实类别锚点c_y之间的欧几里得距离:L_A = ||z - c_y||_2。这个损失项越小,意味着该类别的样本簇越紧密。 - Tuplet Loss:类间分离性约束。光聚拢还不够,还得让不同类别的簇分得开。Tuplet Loss的灵感来源于度量学习,它要求样本到其真实类别锚点的距离,要显著小于到其他所有类别锚点的距离。具体实现时,它巧妙地利用了softmin函数(可以看作是softmax的“相反”操作),将距离向量转化为一个“相似度”分布,然后通过负对数似然来最大化分离边际。其公式为:
L_T = log(1 + Σ_{j≠y} exp(d_y - d_j)),其中d_y是到真实锚点的距离,d_j是到其他锚点的距离。
最终的CAC损失就是这两项的加权和:L_CAC = L_T + λ * L_A。超参数 λ 用来平衡两项的权重。通过同时优化这两个目标,网络学到的logit空间就会呈现出清晰的、边界分明的聚类结构,为后续的未知样本拒绝打下坚实基础。
提示:这里的“锚点”在训练初期是固定的(如对角矩阵),但论文提到在训练完成后,可以用所有训练样本的logit均值来微调锚点位置,使其更贴合数据的真实分布。在小样本场景下,我们也可以在训练过程中定期执行这个更新步骤。
2. 构建支持CAC的PyTorch模型
现在,我们开始用代码将上述思想具象化。首先需要改造一个标准的分类模型,使其能够输出样本到各个锚点的距离,而不仅仅是分类logits。
假设我们有一个在ImageNet上预训练好的ResNet-18作为基础模型。我们的任务是对少量(例如5个)已知的工业缺陷类别进行小样本学习,并希望模型能识别新的缺陷类型。
import torch
import torch.nn as nn
class CACModel(nn.Module):
def __init__(self, base_model, num_known_classes, alpha=10.0):
"""
初始化CAC模型。
Args:
base_model: 预训练的基础特征提取网络(如ResNet)。
num_known_classes: 已知类别的数量。
alpha: 锚点在坐标轴上的幅度,控制锚点间的初始距离。
"""
super().__init__()
self.base_model = base_model
self.num_known_classes = num_known_classes
# 小样本场景下,可以考虑冻结基础模型的底层,防止过拟合
# for param in self.base_model.parameters():
# param.requires_grad = False
# 替换基础模型的最后一层全连接层,使其输出维度等于已知类别数
in_features = self.base_model.fc.in_features
self.base_model.fc = nn.Linear(in_features, num_known_classes)
# 初始化锚点参数:一个 num_


631

被折叠的 条评论
为什么被折叠?



