NumPy - 数组搜索



在 NumPy 中搜索数组

在 NumPy 中搜索数组是指查找满足特定条件的数组元素或检索其索引的过程。

NumPy 提供各种函数来执行搜索,即使是在大型多维数组中,它们如下所示:

  • where() 函数
  • nonzero() 函数
  • searchsorted() 函数
  • argmax() 函数
  • argmin() 函数
  • extract() 函数

使用 where() 函数

NumPy 的 where() 函数用于查找数组中满足给定条件的元素的索引。该函数也可用于根据条件替换元素。以下是语法:

np.where(condition, [x, y])

其中:

  • condition − 要检查的条件。
  • x (可选) − 在条件为真时使用的值。
  • y (可选) − 在条件为假时使用的值。

示例

在以下示例中,我们使用 where() 函数检索数组中大于“25”的元素的索引,并用“0”替换小于或等于“25”的数组元素:

import numpy as np

array = np.array([10, 20, 30, 40, 50])
indices = np.where(array > 25)
print("Indices where array elements are greater than 25:", indices)

# Replacing elements based on condition
modified_array = np.where(array > 25, array, 0)
print("Array after replacing elements <= 25 with 0:", modified_array)

以下是获得的输出:

Indices where array elements are greater than 25: (array([2, 3, 4]),)
Array after replacing elements <= 25 with 0: [ 0  0 30 40 50]

使用 nonzero() 函数

NumPy 的 nonzero() 函数用于查找数组中所有非零元素的索引。它返回一个数组元组,其中每个数组包含沿特定维度的非零元素的索引。

当您想过滤掉零元素或识别稀疏数组中重要元素的位置时,此函数很有用。以下是语法:

numpy.nonzero(a)

其中,a 是要查找非零元素索引的输入数组。

示例

在下面的示例中,我们使用 nonzero() 函数检索一维数组中非零元素的索引:

import numpy as np

array = np.array([0, 1, 2, 0, 3, 0, 4])
nonzero_indices = np.nonzero(array)
print("Indices of non-zero elements:", nonzero_indices)

这将产生以下结果:

Indices of non-zero elements: (array([1, 2, 4, 6]),)

使用 searchsorted() 函数

NumPy 的 searchsorted() 函数用于查找应插入元素以保持排序数组中顺序的索引。

此函数在需要在动态插入元素时保持排序顺序的算法中很有用。以下是语法:

np.searchsorted(sorted_array, values, side='left')

其中:

  • sorted_array − 要搜索的已排序数组。
  • values − 要插入的值。
  • side − 如果为“left”,则给出第一个合适位置的索引。如果为“right”,则给出最后一个合适位置的索引。

示例

在此示例中,我们检索在已排序数组中插入值“2”、“4”和“6”以保持顺序的索引:

import numpy as np

sorted_array = np.array([1, 3, 5, 7, 9])
values = np.array([2, 4, 6])
indices = np.searchsorted(sorted_array, values)
print("Indices where values should be inserted:", indices)

以上代码的输出如下:

Indices where values should be inserted: [1 2 3]

使用 argmax() 函数

NumPy 中的 argmax() 函数用于查找数组中沿指定轴的最大值的索引。如果未指定轴,则返回扁平化数组中最大值的索引。以下是语法:

numpy.argmax(a, axis=None, out=None)

其中:

  • a − 输入数组。
  • axis (可选) − 查找最大值的轴。如果未指定,则在执行操作之前将数组展平。
  • out (可选) − 将结果存储到的位置。如果提供,则其形状必须与预期输出相同。

示例:在二维数组中使用 argmax() 函数

在以下示例中,我们使用 argmax() 函数查找二维数组中沿指定轴的最大值的索引:

import numpy as np

array = np.array([[10, 15, 5], [7, 12, 20]])
index_of_max_along_axis = np.argmax(array, axis=1)
print("Indices of the maximum values along axis 1:", index_of_max_along_axis)

获得的输出如下所示:

Indices of the maximum values along axis 1: [1 2]

示例:在扁平化数组中使用 argmax() 函数

在这里,我们使用 argmax() 函数查找扁平化数组中最大值的索引:

import numpy as np

array = np.array([[10, 15, 5], [7, 12, 20]])
index_of_max_flattened = np.argmax(array)
print("Index of the maximum value in the flattened array:", index_of_max_flattened)

执行以上代码后,我们将获得以下输出:

Index of the maximum value in the flattened array: 5

使用 argmin() 函数

NumPy 中的 argmin() 函数用于查找数组中沿指定轴的最小值的索引。如果未指定轴,则返回扁平化数组中最小值的索引。以下是语法:

numpy.argmin(a, axis=None, out=None)

其中:

  • a − 输入数组。
  • axis (可选) − 查找最小值的轴。如果未指定,则在执行操作之前将数组展平。
  • out (可选) − 将结果存储到的位置。如果提供,则其形状必须与预期输出相同。

示例

在以下示例中,我们使用 argmin() 函数查找二维数组中沿指定轴的最小值的索引:

import numpy as np

array = np.array([[10, 15, 5], [7, 12, 2]])
index_of_min_along_axis = np.argmin(array, axis=1)
print("Indices of the minimum values along axis 1:", index_of_min_along_axis)

产生的结果如下:

Indices of the minimum values along axis 1: [2 2]

使用 extract() 函数

NumPy 中的 extract() 函数用于根据布尔条件从数组中提取元素。它返回一个一维数组,其中只包含输入数组中与布尔条件中的True值对应的元素。

与返回索引的 np.where() 函数不同,np.extract() 函数直接返回满足条件的元素。

以下是语法:

numpy.extract(condition, arr)

其中:

  • condition − 指定要提取哪些元素的布尔数组或条件。它必须与 arr 的形状相同。
  • arr − 要从中提取元素的输入数组。

示例

在下面的示例中,我们使用 np.extract() 函数过滤并返回数组中大于“25”的元素:

import numpy as np

array = np.array([10, 20, 30, 40, 50])
condition = array > 25
extracted_elements = np.extract(condition, array)
print("Elements greater than 25:", extracted_elements)

我们得到如下所示的输出:

Elements greater than 25: [30 40 50]

使用布尔索引进行搜索

NumPy 中的布尔索引用于根据特定条件搜索和过滤数组。它涉及创建一个布尔数组(或掩码),其中每个值根据是否满足条件而为TrueFalse

然后,此布尔数组用于索引到原始数组中,仅提取条件为True的那些元素。

示例

以下是如何在 NumPy 中使用布尔索引根据条件过滤元素的简单示例:

import numpy as np

array = np.array([10, 20, 30, 40, 50])
boolean_mask = array > 25
filtered_array = array[boolean_mask]
print("Filtered array (elements > 25):", filtered_array)

我们得到如下所示的输出:

Filtered array (elements > 25): [30 40 50]
广告