图像处理中的余弦相似度:从CIFAR-10数据集看图片匹配的底层逻辑
最近在整理一个图像检索的小项目时,我重新审视了“相似度”这个看似基础却至关重要的概念。很多刚接触计算机视觉的朋友,一提到图片匹配,可能立刻想到复杂的深度学习模型,比如ResNet、ViT提取的特征。这当然没错,但在深入这些“黑盒”之前,我们不妨先回到一个更本质的问题:抛开神经网络,计算机如何“理解”两张图片是相似的?答案可能比你想象的要“数学”得多——它始于将图像视为空间中的点,并用几何关系来衡量它们的远近。今天,我们就以经典的CIFAR-10数据集为舞台,深入探讨余弦相似度如何成为衡量图片相似性的一把直尺,并亲手用代码实现从图像加载、向量化到阈值判定的完整流程。这篇文章适合那些希望夯实基础,理解算法背后直观几何意义的学习者。我们将不止步于公式,而是聚焦于如何在真实的、尺寸微小的32x32图像上,让理论落地。
1. 核心原理:为什么是余弦,而不是距离?
当我们谈论两张图片的相似度时,一个最朴素的想法是:把图片变成两组数字,然后看它们相差多少。比如计算欧几里得距离(也就是我们常说的直线距离)。这个方法直观,但在图像领域,它常常会“说谎”。
想象一下,你用手机拍了两张桌子的照片,一张光线充足,一张比较昏暗。对于计算机来说,昏暗照片的每个像素值(RGB强度)可能都比明亮照片的要低。如果直接计算对应像素的差值平方和,这个距离会非常大,尽管它们拍的是同一个物体。这里的关键在于,欧氏距离对数值的绝对大小非常敏感。
而余弦相似度巧妙地绕开了这个陷阱。它的核心思想不是比较数值的绝对差异,而是比较两个向量在方向上的“一致性”。它关注的是向量之间的夹角。
一个简单的类比:在评价两份番茄炒蛋时,欧氏距离会纠结于一份盐放了3克,另一份放了5克,从而得出差异很大的结论。而余弦相似度则更关心两份菜中,盐、糖、醋、番茄、鸡蛋这些成分的比例关系是否一致。只要比例相似,即使一份做得量大,一份量小,它们也被认为是“相似”的菜品。
数学上,对于两个向量 A 和 B,它们的余弦相似度定义为:
[ \text{cosine_similarity}(A, B) = \frac{A \cdot B}{|A| |B|} = \frac{\sum_{i=1}^{n} A_i B_i}{\sqrt{\sum_{i=1}^{n} A_i^2} \sqrt{\sum_{i=1}^{n} B_i^2}} ]
这个公式的分子是向量的点积(内积),分母是各自模长(范数)的乘积。它的值域在[-1, 1]之间:
- 1:表示两个向量方向完全相同(夹角0度)。
- 0:表示两个向量正交,毫不相关(夹角90度)。
- -1:表示两个向量方向完全相反(夹角180度)。
在图像处理中,像素值通常为非负(如0-255),因此余弦相似度结果通常在[0, 1]范围内,越接近1越相似。
那么,把图片变成向量,这个“方向”又代表了什么呢?我们可以将其理解为图像中像素强度分布的总体模式或纹理走向。两张内容相似的图片,尽管整体亮度不同(导致向量长度不同),但其像素值的相对分布模式(向量的方向)应该是接近的。这就是余弦相似度在图像比对中往往比纯距离度量更鲁棒的原因。
2. 实战准备:驯服CIFAR-10数据集
理论需要实践的检验,而CIFAR-10是一个绝佳的沙盒。这个数据集包含了10个类别的6万张32x32彩色小图片,尺寸小巧,结构清晰,非常适合进行算法原型验证。我们首先需要把它正确地加载到我们的工作环境中。
我将使用PyTorch来完成数据加载,因为它提供了非常简洁且高效的接口。如果你习惯用TensorFlow或纯NumPy,思路也是完全相通的。
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# 定义数据预处理转换。CIFAR-10原始数据是PIL Image,我们需要将其转换为Tensor,并归一化到[0,1]范围。
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为Tensor,并自动将[0,255]缩放到[0.0,1.0]
])
# 下载并加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 加载测试集(这里我们主要用训练集做实验)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,


1031

被折叠的 条评论
为什么被折叠?



