Python程式:查詢二叉樹中兩個節點之間的距離


假設我們給定一棵二叉樹,並要求找到二叉樹中兩個節點之間的距離。我們像在圖中一樣找到這兩個節點之間的邊,並返回邊數或它們之間的距離。樹的節點結構如下:

data : <integer value>
right : <pointer to another node of the tree>
left : <pointer to another node of the tree>

因此,如果輸入如下所示:

並且我們必須找到節點 2 和 8 之間的距離;則輸出將為 4。

節點 2 和 8 之間的邊為:(2, 3), (3, 5), (5, 7) 和 (7, 8)。它們之間路徑中有 4 條邊,所以距離為 4。

為了解決這個問題,我們將遵循以下步驟:

  • 定義一個函式 findLca()。它將接收根節點、p 和 q。
    • 如果根節點為空,則
      • 返回 null
    • 如果根節點的資料是 (p,q) 中的任何一個,則
      • 返回根節點
    • left := findLca(根節點的左子節點, p, q)
    • right := findLca(根節點的右子節點, p, q)
    • 如果 left 和 right 均不為空,則
      • 返回根節點
    • 返回 left 或 right
  • 定義一個函式 findDist()。它將接收根節點和資料。
    • queue := 一個新的雙端佇列
    • 在佇列的末尾插入一個新的鍵值對 (根節點, 0)
    • 當佇列不為空時,執行以下操作:
      • current := 佇列中第一個鍵值對的第一個值
      • dist := 佇列中第一個鍵值對的第二個值
      • 如果 current 的資料與資料相同,則
        • 返回 dist
      • 如果 current 的左子節點不為空,則
        • 將鍵值對 (current 的左子節點, dist+1) 新增到佇列中
      • 如果 current 的右子節點不為空,則
        • 將鍵值對 (current.right, dist+1) 新增到佇列中
  • node := findLca(root, p, q)
  • 返回 findDist(node, p) + findDist(node, q)

示例

讓我們看看下面的實現,以便更好地理解:

import collections
class TreeNode:
   def __init__(self, data, left = None, right = None):
      self.data = data
      self.left = left
      self.right = right

def insert(temp,data):
   que = []
   que.append(temp)
   while (len(que)):
      temp = que[0]
      que.pop(0)
      if (not temp.left):
         if data is not None:
            temp.left = TreeNode(data)
         else:
            temp.left = TreeNode(0)
         break
      else:
         que.append(temp.left)

      if (not temp.right):
         if data is not None:
            temp.right = TreeNode(data)
         else:
            temp.right = TreeNode(0)
         break
      else:
         que.append(temp.right)

def make_tree(elements):
   Tree = TreeNode(elements[0])
   for element in elements[1:]:
      insert(Tree, element)
   return Tree

def search_node(root, element):
   if (root == None):
      return None

   if (root.data == element):
      return root

   res1 = search_node(root.left, element)
   if res1:
      return res1

   res2 = search_node(root.right, element)
   return res2

def print_tree(root):
   if root is not None:
      print_tree(root.left)
      print(root.data, end == ', ')
      print_tree(root.right)

def findLca(root, p, q):
   if root is None:
      return None
   if root.data in (p,q):
      return root
   left = findLca(root.left, p, q)
   right = findLca(root.right, p, q)
   if left and right:
      return root
   return left or right

def findDist(root, data):
   queue = collections.deque()
   queue.append((root, 0))
   while queue:
      current, dist = queue.popleft()
      if current.data == data:
         return dist
      if current.left: queue.append((current.left, dist+1))
      if current.right: queue.append((current.right, dist+1))

def solve(root, p, q):
   node = findLca(root, p, q)
   return findDist(node, p) + findDist(node, q)

root = make_tree([5, 3, 7, 2, 4, 6, 8])
print(solve(root, 2, 8))

輸入

root = make_tree([5, 3, 7, 2, 4, 6, 8])
print(solve(root, 2, 8))

輸出

4

更新於: 2021年10月7日

625 次瀏覽

開啟你的 職業生涯

透過完成課程獲得認證

開始學習
廣告

© . All rights reserved.