Numpy运算中的axis参数

最近在学习Numpy时,对函数设置中axis(轴)参数有些困惑,学习了这两篇文章: Python · numpy · axisNumpy:对Axis的理解, 感觉自己弄明白了,在这里分享我的理解,欢迎交流和指正。

举例说明更明白,那就直接上例子吧。 数组示例 对于图中所示的数组,shape为(2,3,4),表示这是一个三维数组,每个维度的大小分别是2,3,4,换言之,就是在相应的维度上有2/3/4个元素。以第一维维度(对应axis=0)为例,其大小为2,也就是有两个元素,即图中的两个 3x4 数组。

axis(轴)操作是依据什么规则呢?一般来说,对于N维数据,axis的数字从0到N-1,对应于数组从外到内层(看括号方向)。 我的理解:换个角度,我们可以想象在多维空间进行操作。选取某根坐标轴(axis),相应得到多个切面,然后对这多个切面进行操作。 我们可以在每个维度取出一个“元素”来看看(对应于空间的一个“切面”): 某一维的元素示例 观察结果我们发现,某一维的元素,其大小正好是其余两维的大小。例如对于上面shape为(2,3,4)的数组b,第一维(axis=0)的元素b[1,:,:]的shape为(3,4),正是其余两维的大小。第二维(axis=1)、第三维(axis=2)的情况也是如此。

现在我们来看np.sum()操作,它是把某个方向上的元素相加,对应到空间中,相当于多个切面叠加到一起,合而为一。因此得到的结果的shape与元素一致。

np.sum()示例

np.sort()更有意思,它的结果也体现轴操作的特点。比如axis=0时,注意看是数组b中的[1,2,3,4]和[3,2,5,6]对应逐元素排序,[5,2,4,1]和[1,3,6,7]对应逐元素排序,[3,1,1,2]和[6,9,3,1]对应逐元素排序。 np.sort()示例