PyTorch Tricky API

Posted by Yanchao MURONG on 2023-03-11

torch.mul vs torch.mm

  • torch.mul(a, b) performs an element-wise multiplication of tensors a and b. For example, if a has shape (1, 2) and b also has shape (1, 2), the result is again a (1, 2) tensor.
  • torch.mm(a, b) performs a matrix multiplication of the 2D tensors a and b. For example, if a has shape (1, 2) and b has shape (2, 3), the result is a (1, 3) tensor.
    import torch

    a = torch.rand(1, 2)
    b = torch.rand(1, 2)
    print(torch.mul(a, b)) # returns a 1x2 tensor

    c = torch.rand(2, 3)
    print(torch.mm(a, c)) # returns a 1x3 tensor
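
torch.mul (and the * operator) also broadcasts inputs whose shapes are compatible; a minimal sketch, with shapes picked only for illustration:

    a = torch.rand(3, 1)
    b = torch.rand(1, 4)
    print(torch.mul(a, b).shape) # broadcast to (3, 4), element-wise product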

torch.matmul vs torch.bmm

torch.bmm

Batch matrix multiplication, with restrictions on the input shapes: both inputs must be 3D tensors with the same batch size, and bmm does not broadcast.

B = 4  # batch size
a = torch.rand(B, 5, 1)
b = torch.rand(B, 1, 5)
print(torch.bmm(a, b)) # returns a B*5*5 tensor
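
Because bmm does not broadcast, mismatched batch sizes or non-3D inputs are rejected; a small sketch of both failure modes (shapes made up for illustration):

a = torch.rand(4, 5, 1)
b = torch.rand(2, 1, 5)      # batch size 2 != 4
try:
    torch.bmm(a, b)          # bmm refuses to broadcast the batch dimension
except RuntimeError as e:
    print(e)

c = torch.rand(5, 1)         # 2D input: bmm requires 3D tensors
try:
    torch.bmm(a, c)
except RuntimeError as e:
    print(e)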

torch.matmul(input, other): more flexible

  • (input and other both 1D) -> dot product
  • (input: 2D, other: 2D) -> matrix multiplication (a sketch of these first two cases follows the code block below)
  • (input and other both 3D/4D) -> batched matrix multiplication, with broadcasting over the batch dimensions
B, L = 4, 2  # batch dimensions
a = torch.rand(B, L, 5, 1)
b = torch.rand(B, L, 1, 5)
print(torch.matmul(a,b)) # returns a B*L*5*5 tensor
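
And the 1D and 2D cases mentioned above, as a minimal sketch:

a = torch.rand(5)
b = torch.rand(5)
print(torch.matmul(a, b))       # 1D @ 1D -> dot product (0-dim tensor)

c = torch.rand(2, 5)
d = torch.rand(5, 3)
print(torch.matmul(c, d).shape) # 2D @ 2D -> matrix multiplication, (2, 3)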
  • (input: 2D, other: 1D): matrix-vector product, i.e. each row of input is dotted with other
a = torch.rand(4,3)
b = torch.rand(3)
# (4,3) @ (3,)
# -> each row of a is dotted with b
# -> (4,)
print(torch.matmul(a,b)) # returns a (4,) tensor
  • (input: 1D, other: 2D): a dimension of size 1 is prepended to input, a matrix multiplication is performed, then the added dimension is removed

    a = torch.rand(4)
    b = torch.rand(4,3)
    # (1,4) @ (4,3)
    # -> (1,3) -> remove added dimension
    # -> (3,)
    print(torch.matmul(a,b)) # returns a (3,) tensor

  • (input: 1D, other: 3D or more): a batched matrix multiply; a dimension of size 1 is prepended to input and broadcast over the batch dimensions, then removed from the result

    a = torch.rand(3)
    b = torch.rand(2,3,4)
    # (3,) -> (2,1,3)  # prepend a 1 and broadcast over the batch
    # (2,1,3) @ (2,3,4)
    # -> (2,1,4) -> remove added dimension
    # -> (2,4)
    print(torch.matmul(a,b)) # returns a (2,4) tensor

  • (input: 3D or more, other: 1D): other is broadcast and dotted with input along its last dimension

    a = torch.rand(2,3,4)
    b = torch.rand(4)
    # (4,) is broadcast to (2,3,4)
    # (2,3,4) * (2,3,4) -> dot product along the last dimension
    # -> (2,3)
    print(torch.matmul(a,b)) # returns a (2,3) tensor

  • (input: 4D or more, other: 3D): if the batch dimensions can be broadcast against each other, a batched matrix multiply is performed

    a = torch.rand(10,1,2,4)
    b = torch.rand(2,4,5)
    # batch dims (10,1) and (2,) broadcast to (10,2)
    # -> (10,2,2,4) @ (10,2,4,5), mm over the last two dims
    # -> (10,2,2,5)
    print(torch.matmul(a,b)) # returns a (10,2,2,5) tensor

  • also check the official PyTorch documentation on torch.matmul and torch.bmm for the full broadcasting rules
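
As a quick sanity check (shapes picked arbitrarily), matmul and bmm agree on plain 3D inputs:

    a = torch.rand(4, 5, 6)
    b = torch.rand(4, 6, 7)
    print(torch.allclose(torch.matmul(a, b), torch.bmm(a, b))) # True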

torch squeeze & unsqueeze

torch.squeeze(input, dim=None)

Returns a tensor with all the dimensions of input of size 1 removed.

  • if input is of shape (A x 1 x B x C x 1 x D), then the output tensor will be of shape (A x B x C x D)
  • When dim is given, a squeeze operation is done only in the given dimension. If input is of shape: (A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (AxB)
import torch

A, B, C, D = 2, 3, 4, 5
a = torch.rand(A, 1, B, C, 1, D)
print(torch.squeeze(a).shape) # (A, B, C, D)

b = torch.rand(A, 1, B)
print(torch.squeeze(b, dim=0).shape) # (A, 1, B), unchanged
print(torch.squeeze(b, dim=1).shape) # (A, B)
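
torch.unsqueeze(input, dim) does the opposite: it inserts a dimension of size 1 at the given position. A minimal sketch:

c = torch.rand(3, 4)
print(torch.unsqueeze(c, 0).shape) # (1, 3, 4)
print(torch.unsqueeze(c, 2).shape) # (3, 4, 1)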

torch cat vs stack

  • torch.cat: Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
a = torch.rand(3,4)
b = torch.rand(3,4)
print(torch.cat((a,b), dim=0)) # returns a (6,4) tensor
  • torch.stack: Concatenates a sequence of tensors along a new dimension; all tensors need to be of the same size.
a = torch.rand(3,4)
b = torch.rand(3,4)
print(torch.stack((a,b), dim=0)) # returns a (2,3,4) tensor
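
One way to relate the two (for the a and b above): stacking along a new dimension is the same as unsqueezing each tensor and then concatenating:

stacked = torch.stack((a, b), dim=0)
catted = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)
print(torch.equal(stacked, catted)) # True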