# On `axis` in NumPy and `dim` in PyTorch.
#
# Both keywords name the dimension that a reduction collapses.  With
# keepdims/keepdim=True the collapsed dimension is kept with length 1, so the
# result has the same number of dimensions as the input.
#
# The inputs below are random, so the "sample output" comments only illustrate
# the shapes and layout; the actual values differ on every run.

import numpy as np

array = np.random.randint(low=0, high=6, size=(3, 4))
print(array.shape, array)
# sample output:
# (3, 4) [[5 3 0 4]
#  [4 3 1 0]
#  [2 0 3 2]]

# axis=0 sums down the rows: one total per column, shape (3, 4) -> (1, 4).
array0 = array.sum(axis=0, keepdims=True)
print(array0.shape, array0)
# sample output:
# (1, 4) [[11  6  4  6]]

# axis=1 sums across the columns: one total per row, shape (3, 4) -> (3, 1).
array1 = array.sum(axis=1, keepdims=True)
print(array1.shape, array1)
# sample output:
# (3, 1) [[12]
#  [ 8]
#  [ 7]]

import torch

tensor = torch.randint(low=0, high=6, size=(3, 4))
print(tensor.shape, tensor)
# sample output:
# torch.Size([3, 4]) tensor([[1, 4, 3, 2],
#         [5, 4, 2, 1],
#         [3, 1, 1, 1]])

# dim=0 plays the role of NumPy's axis=0.  `keepdim` is the keyword documented
# for torch.sum, so it is used here instead of the NumPy-style `keepdims`.
tensor0 = tensor.sum(dim=0, keepdim=True)
print(tensor0.shape, tensor0)
# sample output:
# torch.Size([1, 4]) tensor([[9, 9, 6, 4]])

# dim=1 plays the role of NumPy's axis=1.
tensor1 = tensor.sum(dim=1, keepdim=True)
print(tensor1.shape, tensor1)
# sample output:
# torch.Size([3, 1]) tensor([[10],
#         [12],
#         [ 6]])