1. 初识torch.matmul():你的万能矩阵乘法工具箱
如果你刚开始用PyTorch做深度学习,或者从NumPy转过来,可能会被一堆乘法函数搞晕:torch.mm()、torch.mv()、torch.bmm(),还有直接用*号做元素乘。别慌,今天我要跟你聊的torch.matmul(),可以说是PyTorch里最“聪明”、最“通用”的矩阵乘法函数,它能搞定上面提到的大部分场景,而且自带“广播”魔法,让不同形状的张量也能愉快地一起计算。
简单来说,torch.matmul()就是PyTorch为你准备的“一站式”矩阵乘法解决方案。它不像torch.mm()那样死板,只认两个二维矩阵;也不像torch.dot()那样挑剔,非要两个一维向量。matmul这个名字就来自“matrix multiply”,但它实际的能力远不止于二维矩阵。我刚开始用的时候也犯嘀咕,这么多乘法函数我该用哪个?后来项目做多了才发现,torch.matmul()几乎能满足我90%的矩阵运算需求,尤其是在处理维度不确定或者需要广播的时候,用它准没错。
它的核心逻辑是:根据输入的两个张量(Tensor)的维度,自动选择最合适的乘法规则。你可以把它想象成一个经验丰富的厨师,你给他土豆和牛肉(不同维度的张量),他能自动判断是该炒土豆丝还是炖土豆牛肉(执行点积、矩阵乘还是批量乘),最后给你端上合适的菜(输出结果)。这个“自动判断”的规则,就是理解torch.matmul()的关键,也是我们接下来要深入探讨的。我们先从最基础的场景看起,你会发现它其实很直观。
2. 从零开始:五种核心乘法规则详解
官方文档把torch.matmul()的行为分成了五大类。别被“五类”吓到,其实它们有很强的规律性。我习惯把这五类分成两个层面来理解:基础层面(一维和二维张量的各种组合)和高阶层面(涉及三维及以上的批量操作)。我们先彻底搞懂基础层面,这是理解所有复杂情况的地基。
2.1 向量点积:一维 × 一维
这是最简单的情况。当两个输入都是一维张量时,torch.matmul()做的就是标准的向量点积(也叫内积)。点积的规则是:两个向量对应位置的元素相乘,然后把所有乘积加起来,得到一个标量(零维张量)。
import torch
vec_a = torch.tensor([1, 2, 3])
vec_b = torch.tensor([4, 5, 6])
result = torch.matmul(vec_a, vec_b)
print(result) # 输出: tensor(32)
print(result.shape) # 输出: torch.Size([])
计算过程就是 (1*4) + (2*5) + (3*6) = 32。这里有个关键点:两个一维向量的长度必须相同,否则会报错。这个功能跟torch.dot()是完全一样的。在实际的深度学习里,比如计算两个特征向量的相似度(余弦相似度的分子部分),或者某些简单线性层的计算,都会用到这种点积。
2.2 标准矩阵乘法:二维 × 二维
当两个输入都是二维张量时,torch.matmul()就退化成我们线性代数里学的最标准的矩阵乘法。它的行为与torch.mm()函数一致。
matrix_a = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
matrix_b = torch.tensor([[5, 6], [7, 8]]) # 形状 (2, 2)
result = torch.matmul(matrix_a, matrix_b)
print(result)
# 输出:
# tensor([[19, 22],
# [43, 50]])
print(result.shape) # 输出: torch.Size([2, 2])
矩阵乘法的规则是“前行乘后列”。具体来说,结果矩阵中第i行第j列的元素,等于第一个矩阵第i行的所有元素,与第二个矩阵第j列的对应元素相乘再求和。这里有一个必须满足的硬性条件:第一个矩阵的列数必须等于第二个矩阵的行数。在上面的例子里,matrix_a的形状是(2,2),matrix_b的形状也是(2,2),满足a的列数(2)等于b的行数(2),所以可以相乘,得到形状为(2,2)的结果。如果形状不匹配,比如尝试用(2,3)的矩阵去乘(2,2)的矩阵,PyTorch就会抛出RuntimeError,提示你形状无法相乘。
2.3 向量与矩阵相乘:一维 × 二维
从这里开始,torch.matmul()的“


6万+

被折叠的 条评论
为什么被折叠?



