torch.mul vs torch.mm
- torch.mul(a, b) performs an element-wise multiplication of matrices a and b. For example, if the dimensions of a are (1, 2) and the dimensions of b are also (1, 2), the result will still be a matrix with dimensions (1, 2).
- torch.mm(a, b) performs a matrix multiplication of matrices a and b. For example, if the dimensions of a are (1, 2) and the dimensions of b are (2, 3), the result will be a matrix with dimensions (1, 3).
import torch

a = torch.rand(1, 2)
b = torch.rand(1, 2)
print(torch.mul(a, b))  # element-wise product, returns a (1, 2) tensor

c = torch.rand(2, 3)
print(torch.mm(a, c))   # matrix product, returns a (1, 3) tensor
torch.matmul vs torch.bmm
torch.bmm
batch matrix multiplication, with strict restrictions on the input shapes: both arguments must be 3-D tensors with the same batch size, i.e. (b, n, m) @ (b, m, p) -> (b, n, p). For example, a = torch.rand(b, 5, 1) can only be multiplied with a second tensor of shape (b, 1, p), as in the sketch below.
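A minimal runnable sketch (the batch size 16 and the matrix shapes are chosen purely for illustration):

import torch

batch = 16                   # illustrative batch size
a = torch.rand(batch, 5, 1)  # (b, n, m)
c = torch.rand(batch, 1, 8)  # (b, m, p); the batch sizes must match exactly
out = torch.bmm(a, c)
print(out.shape)             # torch.Size([16, 5, 8]), i.e. (b, n, p)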
torch.matmul(input, other): more flexible; its behaviour depends on the dimensionality of the two arguments:
- (input: 1D, other: 1D) -> dot product
- (input: 2D, other: 2D) -> matrix multiplication
- (input: 3D or 4D, other: 3D or 4D) -> batched matrix multiplication like bmm, but more flexible: the batch dimensions only need to be broadcastable, e.g. a = torch.rand(b, l, 5, 1) can be multiplied with a tensor whose batch dimensions broadcast to (b, l) (see the sketch below)
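A minimal sketch of this broadcasting, with batch sizes picked only for illustration (torch.bmm would reject this mix of a 4-D and a 3-D input):

import torch

b, l = 2, 3                 # illustrative batch sizes
a = torch.rand(b, l, 5, 1)
c = torch.rand(l, 1, 4)     # batch dims (l,) broadcast against (b, l)
out = torch.matmul(a, c)
print(out.shape)            # torch.Size([2, 3, 5, 4])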
- (input: 2D, other: 1D) -> other is broadcast against the rows of input and a dot product is taken over the last dimension (a matrix-vector product), e.g. a = torch.rand(4, 3) multiplied with a vector of shape (3,) gives a (4,) result (sketch below)
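A short sketch with illustrative shapes:

import torch

a = torch.rand(4, 3)
v = torch.rand(3)           # length must match the last dimension of a
out = torch.matmul(a, v)    # one dot product per row of a
print(out.shape)            # torch.Size([4])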
- (input: 1D, other: 2D) -> a dimension of size 1 is prepended to input, a matrix multiplication is performed, then the added dimension is removed:
a = torch.rand(4)
b = torch.rand(4,3)
# -> (1,4) @ (4,3)
# -> (1,3) -> remove added dimension
# -> (3)
print(torch.matmul(a,b)) # returns a tensor of shape (3,)

- (input: 1D, other: 3D or more) -> a batched matrix multiply:
a = torch.rand(3)
b = torch.rand(2,3,4)
# -> (3) -> (2,1,3)
# -> (2,1,3) @ (2,3,4)
# -> (2,1,4) -> remove added dimension
# -> (2,4)
print(torch.matmul(a,b)) # returns a tensor of shape (2,4)

- (input: 3D or more, other: 1D) -> dot product along the last dimension of input:
a = torch.rand(2,3,4)
b = torch.rand(4)
# -> (4) broadcast against (2,3,4)
# -> dot product along the last dimension
# -> (2,3)
print(torch.matmul(a,b)) # returns a tensor of shape (2,3)

- (input: 4D or more, other: 3D) -> if the batch dimensions are broadcastable, a batched matrix multiply:
a = torch.rand(10,1,2,4)
b = torch.rand(2,4,5)
# -> (10,1,2,4) -> (10,2,2,4)  # batch dimensions are broadcast
# -> (10,2,2,4) @ (10,2,4,5), mm on the last two dimensions
# -> (10,2,2,5)
print(torch.matmul(a,b)) # returns a tensor of shape (10,2,2,5)

See also the official documentation for torch.matmul and torch.bmm.
torch squeeze & unsqueeze
torch.squeeze(input, dim=None)
Returns a tensor with all the dimensions of input of size 1 removed.
- if input is of shape (A x 1 x B x C x 1 x D), then the output tensor will be of shape (A x B x C x D)
- When dim is given, a squeeze operation is done only in the given dimension. If input is of shape (A x 1 x B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A x B).
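A minimal sketch of both calls (the shapes are illustrative; unsqueeze is shown as well since the heading covers it):

import torch

x = torch.rand(2, 1, 3, 1)
print(torch.squeeze(x).shape)       # torch.Size([2, 3]): every size-1 dimension removed
print(torch.squeeze(x, 0).shape)    # torch.Size([2, 1, 3, 1]): dim 0 has size 2, so unchanged
print(torch.squeeze(x, 1).shape)    # torch.Size([2, 3, 1]): only dim 1 removed
print(torch.unsqueeze(x, 0).shape)  # torch.Size([1, 2, 1, 3, 1]): insert a new size-1 dim at position 0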
torch cat vs stack
- torch.cat: Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
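A short sketch with illustrative shapes; the tensors may differ only in the concatenating dimension:

import torch

a = torch.rand(3, 4)
b = torch.rand(5, 4)                   # differs from a only in dim 0
print(torch.cat((a, b), dim=0).shape)  # torch.Size([8, 4])
print(torch.cat((a, a), dim=1).shape)  # torch.Size([3, 8])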
- torch.stack: Concatenates a sequence of tensors along a new dimension; all tensors need to have the same shape.
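A short sketch with illustrative shapes; unlike cat, every tensor must have exactly the same shape because a new dimension is created:

import torch

a = torch.rand(3, 4)
b = torch.rand(3, 4)
print(torch.stack((a, b), dim=0).shape)  # torch.Size([2, 3, 4]): new leading dimension
print(torch.stack((a, b), dim=1).shape)  # torch.Size([3, 2, 4]): new dimension at position 1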