Minimum Height Trees

A tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.

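For a tree of n nodes labeled from 0 to n - 1, given as a list of undirected edges, you can choose any node as the root. The height of the rooted tree is the number of edges on the longest path from the root to a leaf. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs).

Return a list of the root labels of all MHTs.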

Visual Representation

n = 4, edges = [[1, 0], [1, 2], [1, 3]]

      0
      |
      1
     / \
    2   3

Root 1 has height 1. Every other root has height 2. Result: [1]

Difficulty: Medium

Examples

Input: n = 4, edges = [[1, 0], [1, 2], [1, 3]]
Output: [1]

Rooting the tree at node 1 gives height 1, while any other root gives height 2, so node 1 is the only MHT root.

Approach 1

Level I: Brute Force (DFS/BFS from every node)

Intuition

To find the root that minimizes height, we can treat every node as the root one by one, run a BFS/DFS from it to find its height, and keep the minimum.

Thought Process

  1. For each node i from 0 to n - 1:
    • Start a BFS/DFS from i.
    • Find the maximum distance from i to any other node (this is the height).
  2. Record the height of each node. Find the minimum height.
  3. Return all nodes that produce this minimum height.
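The steps allow either BFS or DFS for the per-root traversal. As a minimal sketch of the DFS option, the hypothetical class `DfsHeights` (not from the original) computes the height for every root recursively. Because a tree has no cycles, passing the parent along is enough to avoid walking back, so no `visited` array is needed. The sketch assumes a connected tree with nodes labeled 0 to n - 1.

```java
import java.util.*;

// Hypothetical helper class (not from the article): per-root heights via recursive DFS.
class DfsHeights {
    // Height of the subtree below `node`; `parent` is skipped because a tree has no cycles.
    static int height(int node, int parent, List<List<Integer>> adj) {
        int best = 0;
        for (int next : adj.get(node)) {
            if (next != parent) best = Math.max(best, 1 + height(next, node, adj));
        }
        return best;
    }

    // heights[i] is the height of the tree when rooted at node i.
    static int[] allHeights(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int[] heights = new int[n];
        for (int i = 0; i < n; i++) heights[i] = height(i, -1, adj);
        return heights;
    }
}
```

For `n = 3, edges = [[0,1],[1,2]]`, `allHeights` returns `[2, 1, 2]`, so node 1 is the unique minimum. The recursion depth equals the height, so a very long path could overflow the default Java stack. A BFS traversal avoids that.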

Why it's inefficient

For each of the N nodes, we traverse V + E elements. In a tree E = N - 1, so each traversal costs O(N) and the total is O(N^2).

Time: O(N^2), one O(N) traversal per node. Space: O(V + E) for the adjacency list.

Detailed Dry Run

n=3, edges=[[0,1],[1,2]]

  • Root 0: Height 2 (path 0-1-2)
  • Root 1: Height 1 (paths 1-0, 1-2)
  • Root 2: Height 2 (path 2-1-0)

Min height: 1. Result: [1]
java
import java.util.*;

class Solution {
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if (n == 1) return Collections.singletonList(0);
        // Build an undirected adjacency list.
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) { adj.get(e[0]).add(e[1]); adj.get(e[1]).add(e[0]); }

        // Root the tree at every node in turn and record its height.
        int[] heights = new int[n];
        int minH = n;
        for (int i = 0; i < n; i++) {
            heights[i] = getHeight(i, adj, n);
            minH = Math.min(minH, heights[i]);
        }
        // Collect every root that achieves the minimum height.
        List<Integer> res = new ArrayList<>();
        for (int i = 0; i < n; i++) if (heights[i] == minH) res.add(i);
        return res;
    }

    // BFS from root; the largest distance reached is the height when rooted there.
    private int getHeight(int root, List<List<Integer>> adj, int n) {
        Deque<int[]> q = new ArrayDeque<>();
        q.add(new int[]{root, 0}); // {node, distance from root}
        boolean[] vis = new boolean[n];
        vis[root] = true;
        int maxD = 0;
        while (!q.isEmpty()) {
            int[] curr = q.poll();
            maxD = Math.max(maxD, curr[1]);
            for (int v : adj.get(curr[0])) {
                if (!vis[v]) { vis[v] = true; q.add(new int[]{v, curr[1] + 1}); }
            }
        }
        return maxD;
    }
}
Approach 2

Level III: Optimal (Leaf Removal - BFS)

Intuition

A tree has at most two centers: the middle node, or the two middle nodes, of its longest path. These centers are exactly the MHT roots. If we repeatedly remove the outer layer of leaves, we eventually converge on the center. This resembles Kahn's algorithm for topological sorting, adapted to undirected trees.
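The claim that the center has one or two nodes can be checked with a different technique from leaf removal. In a tree, a BFS from any node ends at one endpoint of a longest path (a diameter), and a second BFS from that endpoint reaches the other end. The middle node or nodes of that path are the MHT roots.

The sketch below uses a hypothetical class `CenterByDiameter`. It is an alternative shown for illustration, not this article's method, and it assumes a connected tree with n >= 1 nodes labeled 0 to n - 1.

```java
import java.util.*;

// Hypothetical class (not from the article): MHT roots as the middle of a diameter.
class CenterByDiameter {
    // BFS from `src`; fills parent[] with BFS-tree parents and returns the farthest node.
    static int bfsFarthest(int src, List<List<Integer>> adj, int[] parent) {
        int n = adj.size();
        int[] dist = new int[n];
        Arrays.fill(dist, -1);
        Arrays.fill(parent, -1);
        Deque<Integer> q = new ArrayDeque<>();
        q.add(src);
        dist[src] = 0;
        int far = src;
        while (!q.isEmpty()) {
            int u = q.poll();
            if (dist[u] > dist[far]) far = u;
            for (int v : adj.get(u)) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    parent[v] = u;
                    q.add(v);
                }
            }
        }
        return far;
    }

    static List<Integer> centers(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int[] parent = new int[n];
        int a = bfsFarthest(0, adj, parent); // one endpoint of a diameter
        int b = bfsFarthest(a, adj, parent); // other endpoint; parent[] now leads back to a
        // Walk from b back to a to materialize the diameter path.
        List<Integer> path = new ArrayList<>();
        for (int v = b; v != -1; v = parent[v]) path.add(v);
        int len = path.size();
        // Odd node count: one middle node. Even node count: the two middle nodes.
        if (len % 2 == 1) return List.of(path.get(len / 2));
        return List.of(path.get(len / 2 - 1), path.get(len / 2));
    }
}
```

On `n = 4, edges = [[1,0],[1,2],[1,3]]` this returns `[1]`. On the path 0-1-2-3 it returns nodes 1 and 2, in an order that depends on the traversal. Like leaf trimming, it runs in O(V + E).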

Thought Process

  1. Use an adjacency list and calculate the degree of each node.
  2. Add all nodes with degree == 1 to a Queue (initial leaves).
  3. In a loop, while remainingNodes > 2:
    • Subtract current leaves count from remainingNodes.
    • For each leaf in the queue:
      • Pop it and find its only remaining neighbor nbr.
      • Decrement degree[nbr].
      • If degree[nbr] becomes 1, add nbr to the next level queue.
  4. The final remaining nodes in the queue are the roots of MHTs.
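A variant that follows the numbered steps literally can keep a `degree` array and a queue instead of deleting edges from adjacency sets. This is a minimal sketch in a hypothetical class `DegreeTrimming`, assuming n >= 1.

Setting a processed leaf's degree to 0 is a choice made for this sketch, not part of the steps. It means each leaf only decrements its one surviving neighbor.

```java
import java.util.*;

// Hypothetical class (not from the article): leaf trimming with a degree array and a queue.
class DegreeTrimming {
    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if (n == 1) return List.of(0);
        // Step 1: adjacency list plus the degree of each node.
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        int[] degree = new int[n];
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
            degree[e[0]]++;
            degree[e[1]]++;
        }
        // Step 2: the initial leaves.
        Deque<Integer> queue = new ArrayDeque<>();
        for (int i = 0; i < n; i++) if (degree[i] == 1) queue.add(i);
        // Step 3: peel one whole layer of leaves per iteration.
        int remaining = n;
        while (remaining > 2) {
            int layerSize = queue.size();
            remaining -= layerSize;
            for (int k = 0; k < layerSize; k++) {
                int leaf = queue.poll();
                degree[leaf] = 0; // mark as removed (sketch's own convention)
                for (int nbr : adj.get(leaf)) {
                    // Removed nodes have degree 0, so only the surviving neighbor passes.
                    if (degree[nbr] > 0 && --degree[nbr] == 1) queue.add(nbr);
                }
            }
        }
        // Step 4: the nodes left in the queue are the MHT roots.
        return new ArrayList<>(queue);
    }
}
```

For `n = 4, edges = [[1,0],[1,2],[1,3]]` it returns `[1]`. Plain lists avoid hashing overhead. Each leaf scans its full neighbor list once, which still totals O(V + E).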

Pattern: Topological Trimming

Time: O(V + E), since each node and edge is visited once. Space: O(V + E) for the adjacency list.

Detailed Dry Run

n=4, edges=[[1,0],[1,2],[1,3]]

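  1. Degrees: {0:1, 1:3, 2:1, 3:1}. Leaves: [0, 2, 3].
  2. remaining = 4, and 4 > 2, so peel this layer: remaining becomes 4 - 3 = 1.
  3. Remove leaf 0: degree[1] goes 3 -> 2. Remove leaf 2: degree[1] goes 2 -> 1, so node 1 joins the next layer. Remove leaf 3: degree[1] goes 1 -> 0.
  4. Next leaves: [1]. remaining = 1 is not > 2, so the loop stops. Result: [1]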
java
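import java.util.*;

class Solution {
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        // A tree with 0 or 1 node: every node (if any) is an MHT root.
        if (n < 2) {
            List<Integer> res = new ArrayList<>();
            for (int i = 0; i < n; i++) res.add(i);
            return res;
        }
        // Adjacency sets so a removed leaf can be deleted from its neighbor in O(1).
        List<Set<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new HashSet<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        // Initial leaves: nodes with exactly one neighbor.
        List<Integer> leaves = new ArrayList<>();
        for (int i = 0; i < n; i++) if (adj.get(i).size() == 1) leaves.add(i);
        int remaining = n;
        // Peel one layer of leaves per iteration until at most two nodes remain.
        while (remaining > 2) {
            remaining -= leaves.size();
            List<Integer> newLeaves = new ArrayList<>();
            for (int leaf : leaves) {
                // A leaf has exactly one remaining neighbor.
                int neighbor = adj.get(leaf).iterator().next();
                adj.get(neighbor).remove(leaf);
                // The neighbor joins the next layer once it has one neighbor left.
                if (adj.get(neighbor).size() == 1) newLeaves.add(neighbor);
            }
            leaves = newLeaves;
        }
        // The one or two surviving nodes are the MHT roots.
        return leaves;
    }
}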
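To reproduce the dry run on other inputs, the trimming loop can record each layer of leaves before removing it. The hypothetical `TrimTrace.layers` below is a minimal sketch built on the same adjacency-set loop as the solution above. It returns the successive layers, and the last entry holds the MHT roots. It assumes a connected tree with n >= 1.

```java
import java.util.*;

// Hypothetical class (not from the article): records each leaf layer peeled by trimming.
class TrimTrace {
    // Returns the successive leaf layers; the final entry is the list of MHT roots.
    static List<List<Integer>> layers(int n, int[][] edges) {
        List<Set<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new HashSet<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        // size() <= 1 also covers the single-node tree, whose only node has no neighbors.
        List<Integer> leaves = new ArrayList<>();
        for (int i = 0; i < n; i++) if (adj.get(i).size() <= 1) leaves.add(i);
        List<List<Integer>> trace = new ArrayList<>();
        int remaining = n;
        while (remaining > 2) {
            trace.add(leaves); // this layer is about to be peeled off
            remaining -= leaves.size();
            List<Integer> next = new ArrayList<>();
            for (int leaf : leaves) {
                int nbr = adj.get(leaf).iterator().next();
                adj.get(nbr).remove(leaf);
                if (adj.get(nbr).size() == 1) next.add(nbr);
            }
            leaves = next;
        }
        trace.add(leaves); // final layer: the MHT roots
        return trace;
    }
}
```

For `n = 4, edges = [[1,0],[1,2],[1,3]]` it returns `[[0, 2, 3], [1]]`. This matches the dry run: one layer of three leaves is peeled, leaving node 1. On the path 0-1-2-3-4 it returns `[[0, 4], [1, 3], [2]]`.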