NumPy expand_dims() 函数



NumPy 的expand_dims()函数用于向数组添加新的轴或维度。此函数接收一个输入数组和一个轴参数,该参数指定在何处插入新维度。

此函数的结果是一个添加了指定轴的新数组,这对于调整数组形状以满足某些运算或函数的要求非常有用。

例如,扩展维度通常用于通过使其与广播或矩阵运算兼容来将一维数组转换为二维数组。

语法

NumPy expand_dims() 函数的语法如下:

numpy.expand_dims(a, axis)

参数

以下是 NumPy expand_dims() 函数的参数:

  • a : 输入数组。
  • axis (int): 这是在扩展轴中放置新轴的位置。如果我们提供一个负整数,它将从最后一个轴到第一个轴进行计数。

返回值

此函数返回维度增加一的数组视图。

示例 1

以下是 NumPy expand_dims() 函数的基本示例,它通过将其转换为形状为 (1, 3) 的二维数组,在 1D 数组的开头添加一个新轴:

import numpy as np

# Original 1D array
arr = np.array([1, 2, 3])

# Expand dimensions
expanded_arr = np.expand_dims(arr, axis=0)

print("Original array:")
print(arr)
print("Shape:", arr.shape)

print("\nExpanded array:")
print(expanded_arr)
print("Shape:")
print(expanded_arr.shape)

输出

Original array:
[1 2 3]
Shape: (3,)

Expanded array:
[[1 2 3]]
Shape: 
(1, 3)

示例 2

此示例通过将其形状从 (2, 2, 2) 更改为 (2, 2, 2, 1),在 3D 数组的末尾添加一个新轴:

import numpy as np

# Original 3D array
arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

# Expand dimensions
expanded_arr = np.expand_dims(arr, axis=-1)

print("Original array:")
print(arr)
print("Shape:", arr.shape)

print("\nExpanded array:")
print(expanded_arr)
print("Shape:")
print(expanded_arr.shape)

执行上述代码后,我们将得到以下结果:

Original array:
[[[1 2]
  [3 4]]

 [[5 6]
  [7 8]]]
Shape: 
(2, 2, 2)

Expanded array:
[[[[1]
   [2]]

  [[3]
   [4]]]


 [[[5]
   [6]]

  [[7]
   [8]]]]
Shape: (2, 2, 2, 1)

示例 3

以下示例显示了如何使用expand_dims()函数通过根据需要调整其形状和维度来向数组添加新维度:

import numpy as np 

# Create a 2D array
x = np.array([[1, 2], [3, 4]])

print('Array x:')
print(x)
print('\n')

# Add a new axis at position 0
y = np.expand_dims(x, axis=0)

print('Array y with a new axis added at position 0:')
print(y)
print('\n')

# Print the shapes of x and y
print('The shape of x and y arrays:')
print(x.shape, y.shape)
print('\n')

# Add a new axis at position 1
y = np.expand_dims(x, axis=1)

print('Array y after inserting axis at position 1:')
print(y)
print('\n')

# Print the number of dimensions (ndim) for x and y
print('x.ndim and y.ndim:')
print(x.ndim, y.ndim)
print('\n')

# Print the shapes of x and y
print('x.shape and y.shape:')
print(x.shape, y.shape)

执行上述代码后,我们将得到以下结果:

Array x:
[[1 2]
 [3 4]]


Array y with a new axis added at position 0:
[[[1 2]
  [3 4]]]


The shape of x and y arrays:
(2, 2) (1, 2, 2)


Array y after inserting axis at position 1:
[[[1 2]]

 [[3 4]]]


x.ndim and y.ndim:
2 3


x.shape and y.shape:
(2, 2) (2, 1, 2)
numpy_array_manipulation.htm
广告
© . All rights reserved.