关于 NumPy 中的 axis 和 PyTorch 中的 dim

import numpy as np

# Demo: how `axis` selects the dimension a NumPy reduction collapses.
# The post originally drew this matrix with
# np.random.randint(low=0, high=6, size=(3, 4)) but set no seed, so the
# printed output below could not be reproduced. The values shown in the
# post are used directly instead.
array = np.array([[5, 3, 0, 4],
                  [4, 3, 1, 0],
                  [2, 0, 3, 2]])
print(array.shape, array)
# (3, 4) [[5 3 0 4]
#  [4 3 1 0]
#  [2 0 3 2]]

# axis=0 is the reduced axis: each column is summed over its 3 rows.
# keepdims=True keeps that axis with length 1, giving shape (1, 4)
# rather than (4,).
array0 = array.sum(axis=0, keepdims=True)
print(array0.shape, array0)
# (1, 4) [[11  6  4  6]]

# axis=1 is the reduced axis: each row is summed over its 4 columns,
# giving shape (3, 1).
array1 = array.sum(axis=1, keepdims=True)
print(array1.shape, array1)
# (3, 1) [[12]
#  [ 8]
#  [ 7]]
import torch

# Demo: PyTorch calls the reduction dimension `dim` (NumPy: `axis`) and the
# shape-preserving flag `keepdim` (NumPy: `keepdims`).
# The post originally drew this tensor with
# torch.randint(low=0, high=6, size=(3, 4)) but set no seed, so the printed
# output below could not be reproduced. The values shown in the post are
# used directly instead. The dtype is int64, the same as randint's default.
tensor = torch.tensor([[1, 4, 3, 2],
                       [5, 4, 2, 1],
                       [3, 1, 1, 1]])
print(tensor.shape, tensor)
# torch.Size([3, 4]) tensor([[1, 4, 3, 2],
#         [5, 4, 2, 1],
#         [3, 1, 1, 1]])

# dim=0 is the reduced dimension: each column is summed over its 3 rows,
# giving shape [1, 4]. `keepdim` is the documented keyword. The `keepdims`
# spelling used originally also works, but only as an undocumented
# NumPy-compatibility alias.
tensor0 = tensor.sum(dim=0, keepdim=True)
print(tensor0.shape, tensor0)
# torch.Size([1, 4]) tensor([[9, 9, 6, 4]])

# dim=1 is the reduced dimension: each row is summed over its 4 columns,
# giving shape [3, 1].
tensor1 = tensor.sum(dim=1, keepdim=True)
print(tensor1.shape, tensor1)
# torch.Size([3, 1]) tensor([[10],
#         [12],
#         [ 6]])