在線性代數(shù)中,矩陣是可以相乘的,在pytorch中矩陣也可以相乘。今天小編就帶來一篇pytorch乘法介紹,里面介紹了pytorch中的matmul與mm和bmm的區(qū)別說明。讓我們來了解pytorch中是怎么實現(xiàn)矩陣乘法的吧。
先看下官網(wǎng)上對這三個函數(shù)的介紹。
matmul
mm
bmm
顧名思義, 就是兩個batch矩陣乘法.
結(jié)論
從官方文檔可以看出
1、mm只能進(jìn)行矩陣乘法,也就是輸入的兩個tensor維度只能是( n × m ) (n imes m)(n×m)和( m × p ) (m imes p)(m×p)
2、bmm是兩個三維張量相乘, 兩個輸入tensor維度是( b × n × m ) (b imes n imes m)(b×n×m)和( b × m × p ) (b imes m imes p)(b×m×p), 第一維b代表batch size,輸出為( b × n × p ) (b imes n imes p)(b×n×p)
3、matmul可以進(jìn)行張量乘法, 輸入可以是高維.
補充:torch中的幾種乘法。torch.mm, torch.mul, torch.matmul
一、點乘
點乘都是broadcast的,可以用torch.mul(a, b)實現(xiàn),也可以直接用*實現(xiàn)。
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
[2.],
[3.]])
>>> torch.mul(a, b)
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
當(dāng)a, b維度不一致時,會自動填充到相同維度相點乘。
二、矩陣乘
矩陣相乘有torch.mm和torch.matmul兩個函數(shù)。其中前一個是針對二維矩陣,后一個是高維。當(dāng)torch.mm用于大于二維時將報錯。
>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
[4., 4.],
[4., 4.]])
>>> a = torch.ones(3,4)
>>> b = torch.ones(5,4,2)
>>> torch.matmul(a, b).shape
torch.Size([5, 3, 2])
>>> a = torch.ones(5,4,2)
>>> b = torch.ones(5,2,3)
>>> torch.matmul(a, b).shape
torch.Size([5, 4, 3])
>>> a = torch.ones(5,4,2)
>>> b = torch.ones(5,2,3)
>>> torch.matmul(b, a).shape
報錯。
以上就是pytorch乘法介紹的全部內(nèi)容,希望能給大家一個參考,也希望大家多多支持W3Cschool。如有錯誤或未考慮完全的地方,望不吝賜教。