NumPy - 数组过滤



NumPy 中的数组过滤

NumPy 中的数组过滤允许您根据特定条件选择和处理数据集的子集。此过程可用于提取相关数据、执行条件运算以及分析数据集的子集。

我们可以在 NumPy 中通过创建一个布尔数组(掩码)来执行过滤,其中每个元素指示原始数组中相应的元素是否满足指定条件。然后使用此掩码索引原始数组,提取满足条件的元素。

NumPy 提供了多种通过**布尔索引**和**条件运算**过滤数组的方法。

使用布尔索引的基本过滤

布尔索引允许您根据条件过滤数组元素。通过对数组应用条件,您将获得一个布尔数组,您可以使用它来索引原始数组。

示例

在以下示例中,我们从给定数组中过滤大于值“10”的元素:

import numpy as np

# Creating an array
array = np.array([1, 5, 8, 12, 20, 3])

# Define the condition
condition = array > 10

# Apply the condition to filter the array
filtered_array = array[condition]

print("Original Array:", array)
print("Filtered Array (elements > 10):", filtered_array)

以下是获得的输出:

Original Array: [ 1  5  8 12 20  3]
Filtered Array (elements > 10): [12 20]

使用多个条件过滤

使用多个条件过滤允许您从 NumPy 数组中选择同时满足多个条件的元素。这是通过使用以下逻辑运算符组合多个布尔条件来实现的:

  • **与(&) -** 选择满足两个条件的元素。
  • **或(|) -** 选择满足至少一个条件的元素。
  • **非(~) -** 选择不满足条件的元素。

表示组合条件的生成的布尔数组然后用于索引原始数组,提取满足所有指定条件的元素。

示例

在此示例中,我们使用多个条件过滤一定范围内的元素:

import numpy as np

# Creating an array
array = np.array([1, 5, 8, 12, 20, 3])

# Define multiple conditions
condition = (array > 5) & (array < 15)

# Apply the conditions to filter the array
filtered_array = array[condition]

print("Original Array:", array)
print("Filtered Array (5 < elements < 15):", filtered_array)  

这将产生以下结果:

Original Array: [ 1  5  8 12 20  3]
Filtered Array (5 < elements < 15): [ 8 12]

使用函数过滤

使用函数进行过滤时,您通常定义一个函数,该函数以数组元素作为输入并返回一个布尔值(True 或 False),指示是否应将每个元素包含在结果中。

然后将此函数应用于数组,生成的布尔数组用于索引和过滤原始数据。

示例:使用 where() 函数过滤

在下面的示例中,我们使用 where() 函数过滤 NumPy 中的元素:

import numpy as np

# Creating an array
array = np.array([1, 5, 8, 12, 20, 3])

# Define the condition
condition = array > 10

# Filter elements
filtered_indices = np.where(condition)
filtered_array = array[filtered_indices]

print("Original Array:", array)
print("Filtered Array (elements > 10) using np.where:", filtered_array)

此函数返回条件为“True”的索引。这些索引用于提取如下所示的过滤元素:

Original Array: [ 1  5  8 12 20  3]
Filtered Array (elements > 10) using np.where: [12 20]

示例:使用自定义函数过滤

让我们来看一个使用自定义函数根据特定条件过滤数组的示例:

import numpy as np

# Create a NumPy array
array = np.array([10, 15, 20, 25, 30, 35])

# Define a custom function for filtering
def is_prime(num):
   """Return True if num is a prime number, False otherwise."""
   if num <= 1:
      return False
   for i in range(2, int(np.sqrt(num)) + 1):
      if num % i == 0:
         return False
   return True

# Apply the function to each element of the array
mask = np.array([is_prime(x) for x in array])

# Use the mask to filter the array
filtered_array = array[mask]

print("Original Array:", array)
print("Mask (prime numbers):", mask)
print("Filtered Array (prime numbers):", filtered_array)                                

获得的输出如下所示:

Original Array: [10 15 20 25 30 35]
Mask (prime numbers): [False False False False False False]
Filtered Array (prime numbers): []

多维数组中的过滤

在多维数组中,可以使用布尔索引进行过滤,类似于一维数组。但是,您需要确保过滤条件已正确应用以处理数组的维度。

以下是多维数组中过滤所涉及的步骤:

  • **定义过滤条件 -** 创建应用于数组中元素的布尔条件。这些条件可以基于值或其他标准。
  • **跨维度应用条件 -** 使用这些条件来索引和选择元素。对于多维数组,您可能需要处理特定维度的条件或跨所有维度应用条件。

示例

考虑一个 2D 数组,我们希望根据应用于特定列中元素的条件过滤行:

import numpy as np

# Create a 2D NumPy array
array = np.array([[10, 20, 30],
                  [15, 25, 35],
                  [20, 30, 40]])

# Define a condition for filtering
# Select rows where the value in the second column is greater than 25
condition = array[:, 1] > 25  

# Use the condition to filter the array
filtered_array = array[condition]

print("Original Array:\n", array)
print("Condition (values in second column > 25):", condition)
print("Filtered Array:\n", filtered_array)                               

执行上述代码后,我们将获得以下输出:

Original Array:
[[10 20 30]
 [15 25 35]
 [20 30 40]]
Condition (values in second column > 25): [False False  True]
Filtered Array:
 [[20 30 40]]
广告