Python程序:查找二叉树中两个节点之间的距离


假设我们给定一棵二叉树,并要求找到二叉树中两个节点之间的距离。我们像在图中一样找到这两个节点之间的边,并返回边数或它们之间的距离。树的节点结构如下:

data : <integer value> right : <pointer to another node of the tree> left : <pointer to another node of the tree>

因此,如果输入如下所示:

并且我们必须找到节点 2 和 8 之间的距离;则输出将为 4。

节点 2 和 8 之间的边为:(2, 3), (3, 5), (5, 7) 和 (7, 8)。它们之间路径中有 4 条边,所以距离为 4。

为了解决这个问题,我们将遵循以下步骤:

  • 定义一个函数 findLca()。它将接收根节点、p 和 q。
    • 如果根节点为空,则
      • 返回 null
    • 如果根节点的数据是 (p,q) 中的任何一个,则
      • 返回根节点
    • left := findLca(根节点的左子节点, p, q)
    • right := findLca(根节点的右子节点, p, q)
    • 如果 left 和 right 均不为空,则
      • 返回根节点
    • 返回 left 或 right
  • 定义一个函数 findDist()。它将接收根节点和数据。
    • queue := 一个新的双端队列
    • 在队列的末尾插入一个新的键值对 (根节点, 0)
    • 当队列不为空时,执行以下操作:
      • current := 队列中第一个键值对的第一个值
      • dist := 队列中第一个键值对的第二个值
      • 如果 current 的数据与数据相同,则
        • 返回 dist
      • 如果 current 的左子节点不为空,则
        • 将键值对 (current 的左子节点, dist+1) 添加到队列中
      • 如果 current 的右子节点不为空,则
        • 将键值对 (current.right, dist+1) 添加到队列中
  • node := findLca(root, p, q)
  • 返回 findDist(node, p) + findDist(node, q)

示例

让我们看看下面的实现,以便更好地理解:

Open Compiler
import collections class TreeNode: def __init__(self, data, left = None, right = None): self.data = data self.left = left self.right = right def insert(temp,data): que = [] que.append(temp) while (len(que)): temp = que[0] que.pop(0) if (not temp.left): if data is not None: temp.left = TreeNode(data) else: temp.left = TreeNode(0) break else: que.append(temp.left) if (not temp.right): if data is not None: temp.right = TreeNode(data) else: temp.right = TreeNode(0) break else: que.append(temp.right) def make_tree(elements): Tree = TreeNode(elements[0]) for element in elements[1:]: insert(Tree, element) return Tree def search_node(root, element): if (root == None): return None if (root.data == element): return root res1 = search_node(root.left, element) if res1: return res1 res2 = search_node(root.right, element) return res2 def print_tree(root): if root is not None: print_tree(root.left) print(root.data, end == ', ') print_tree(root.right) def findLca(root, p, q): if root is None: return None if root.data in (p,q): return root left = findLca(root.left, p, q) right = findLca(root.right, p, q) if left and right: return root return left or right def findDist(root, data): queue = collections.deque() queue.append((root, 0)) while queue: current, dist = queue.popleft() if current.data == data: return dist if current.left: queue.append((current.left, dist+1)) if current.right: queue.append((current.right, dist+1)) def solve(root, p, q): node = findLca(root, p, q) return findDist(node, p) + findDist(node, q) root = make_tree([5, 3, 7, 2, 4, 6, 8]) print(solve(root, 2, 8))

输入

root = make_tree([5, 3, 7, 2, 4, 6, 8])
print(solve(root, 2, 8))

Learn Python in-depth with real-world projects through our Python certification course. Enroll and become a certified expert to boost your career.

输出

4

更新于: 2021年10月7日

625 次浏览

开启你的 职业生涯

通过完成课程获得认证

开始学习
广告