Kth Smallest Element in Sorted Matrix

Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.
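The sketch below contrasts the two readings on the example matrix used throughout this page. Counting duplicates, the 8th smallest element is 13. Counting only distinct values, the 8th would be 15. The class `DistinctVsKth` and its two helpers are hypothetical names used only for this illustration. Both helpers simply sort the flattened matrix (the Approach 1 idea) rather than exploiting the sorted rows.

```java
import java.util.Arrays;

public class DistinctVsKth {
    // kth smallest with duplicates counted (k is 1-indexed); this is what the problem asks for
    public static int kthSmallest(int[][] matrix, int k) {
        int[] all = Arrays.stream(matrix).flatMapToInt(Arrays::stream).sorted().toArray();
        return all[k - 1];
    }

    // kth distinct value (k is 1-indexed); shown only for contrast, NOT the problem's definition
    public static int kthDistinct(int[][] matrix, int k) {
        int[] distinct = Arrays.stream(matrix).flatMapToInt(Arrays::stream).distinct().sorted().toArray();
        return distinct[k - 1];
    }

    public static void main(String[] args) {
        int[][] m = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallest(m, 8)); // 13: the two 13s occupy positions 7 and 8
        System.out.println(kthDistinct(m, 8)); // 15: the distinct values are 1, 5, 9, 10, 11, 12, 13, 15
    }
}
```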

Visual Representation

matrix = [  1,  5,  9 ]
         [ 10, 11, 13 ]
         [ 12, 13, 15 ]
k = 8

All elements sorted: [1, 5, 9, 10, 11, 12, 13, 13, 15]
The 8th smallest is 13.

Min-Heap approach (similar to Merge K Sorted Lists):
Initial heap: [(1,0,0)]   <- (val, row, col)
Expand the smallest entry: add its right neighbor and possibly the one below it.
Medium

Examples

Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Input: matrix = [[-5]], k = 1
Output: -5
Approach 1

Level I: Flatten and Sort

Intuition

The simplest approach is to collect all elements from the matrix into a flat list, sort them, and return the element at index k - 1. While not leveraging the sorted property of the matrix, this is intuitive and correct.

Time: O(N^2 log N), where N is the matrix dimension, for sorting all N^2 elements. Space: O(N^2) to store all elements in a flat array.

Detailed Dry Run

matrix = [[1,5,9],[10,11,13],[12,13,15]], k=8

  1. Flatten: [1, 5, 9, 10, 11, 13, 12, 13, 15]
  2. Sort: [1, 5, 9, 10, 11, 12, 13, 13, 15]
  3. Return sorted[k-1] = sorted[7] = 13
java
import java.util.*;

public class Main {
    public static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        int[] flat = new int[n * n];
        int idx = 0;
        for (int[] row : matrix) {
            for (int val : row) {
                flat[idx++] = val;
            }
        }
        Arrays.sort(flat);
        return flat[k - 1];
    }

    public static void main(String[] args) {
        int[][] m = {{1,5,9},{10,11,13},{12,13,15}};
        System.out.println(kthSmallest(m, 8)); // 13
    }
}
Approach 2

Level II: Binary Search on Value Range

Intuition

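Because every row and column is sorted, we can count how many entries are <= any value X in O(N) time with a staircase walk: start at the bottom-left corner; if the current entry is <= X, then every entry above it in that column is too, so add row + 1 to the count and move right; otherwise move up. We then binary search over the value range [matrix[0][0], matrix[n-1][n-1]] for the smallest value X whose count is at least k. That value is always an element of the matrix: if it were not, X - 1 would have the same count, contradicting the minimality of X.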

Time: O(N log(Max - Min)), where Max and Min are the largest and smallest values in the matrix; each binary-search step runs one O(N) count. Space: O(1).

Detailed Dry Run

matrix = [[1,5,9],[10,11,13],[12,13,15]], k=8, search range [1, 15]

  1. low=1, high=15, mid=8. Entries <= 8: 1, 5. Count = 2 < 8, so low = 9.
  2. low=9, high=15, mid=12. Entries <= 12: 1, 5, 9, 10, 11, 12. Count = 6 < 8, so low = 13.
  3. low=13, high=15, mid=14. Entries <= 14: all but 15. Count = 8 >= 8, so high = 14 (14 may still be the answer).
  4. low=13, high=14, mid=13. Entries <= 13: all but 15. Count = 8 >= 8, so high = 13.
  5. low == high == 13. Return 13.
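To replay the iterations above, here is a small tracing sketch. It uses the same staircase count as the solution below but prints `low`, `high`, `mid`, and the count at every step. The class `BinarySearchTrace` and the method `kthSmallestTraced` are hypothetical names for this illustration only.

```java
public class BinarySearchTrace {
    // Staircase count of entries <= target, walking from the bottom-left corner
    public static int countLessEqual(int[][] matrix, int target) {
        int n = matrix.length, count = 0;
        int row = n - 1, col = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                count += row + 1; // matrix[0..row][col] are all <= target
                col++;
            } else {
                row--;
            }
        }
        return count;
    }

    // Same binary search as the solution, with one printed line per iteration
    public static int kthSmallestTraced(int[][] matrix, int k) {
        int n = matrix.length;
        int low = matrix[0][0], high = matrix[n - 1][n - 1];
        while (low < high) {
            int mid = low + (high - low) / 2;
            int count = countLessEqual(matrix, mid);
            System.out.printf("low=%d high=%d mid=%d count=%d%n", low, high, mid, count);
            if (count < k) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    public static void main(String[] args) {
        int[][] m = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallestTraced(m, 8));
    }
}
```

Running it should print:

    low=1 high=15 mid=8 count=2
    low=9 high=15 mid=12 count=6
    low=13 high=15 mid=14 count=8
    low=13 high=14 mid=13 count=8
    13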

java
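public class Main {
    public static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // The answer lies between the smallest and the largest value
        int low = matrix[0][0], high = matrix[n - 1][n - 1];

        // Find the smallest value whose count of entries <= it is at least k
        while (low < high) {
            int mid = low + (high - low) / 2; // midpoint without computing low + high
            if (countLessEqual(matrix, mid) < k) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    // Staircase walk from the bottom-left corner: O(N)
    private static int countLessEqual(int[][] matrix, int target) {
        int n = matrix.length, count = 0;
        int row = n - 1, col = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // matrix[0..row][col] are all <= target
                count += row + 1;
                col++;
            } else {
                row--;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[][] m = {{1,5,9},{10,11,13},{12,13,15}};
        System.out.println(kthSmallest(m, 8)); // 13
    }
}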
Approach 3

Level III: Min-Heap (K Merged Sorted Lists)

Intuition

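The matrix is essentially N sorted rows that we want to merge, so we can reuse the technique from "Merge K Sorted Lists": seed a Min-Heap with the first element of each row (that is, the whole first column), then repeatedly pop the minimum and push the next element from the same row, if there is one. After k pops, the last popped value is the answer. Optimization: only the first min(N, k) rows need to be seeded. Column 0 already supplies k entries (matrix[0][0] through matrix[k-1][0]) that are no larger than anything in rows k and below, so those rows can never supply a smaller kth element.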

Time: O((N + K) log N), where N is the matrix dimension: N initial insertions plus K pops, each followed by at most one push, and every heap operation costs O(log N). Space: O(N), since the Min-Heap holds at most one entry per row.

Detailed Dry Run

matrix = [[1,5,9],[10,11,13],[12,13,15]], k=8 Heap = [(1,r:0,c:0), (10,r:1,c:0), (12,r:2,c:0)]

  1. Pop (1). Push right (5,0,1). Heap=[(5,0,1),(10,1,0),(12,2,0)]. Count=1.
  2. Pop (5). Push (9,0,2). Count=2.
  3. Pop (9). No right (col 3 out of range). Count=3.
  4. Pop (10). Push (11,1,1). Count=4.
  5. Pop (11). Push (13,1,2). Count=5.
  6. Pop (12). Push (13,2,1). Count=6.
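  7. Pop (13,1,2). Its right neighbor (col 3) is out of range, so nothing is pushed. Count=7.
  8. Pop (13,2,1). Count=8. Return 13. (If the two 13s come out in the other order, the answer is still 13.)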
java
import java.util.*;

public class Main {
    public static int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // Min-Heap of {value, row, col}; Integer.compare avoids the overflow risk of a[0] - b[0]
        PriorityQueue<int[]> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));

        // Seed with the first element of each row; rows with index >= k cannot hold the answer
        for (int r = 0; r < Math.min(n, k); r++) {
            minHeap.offer(new int[]{matrix[r][0], r, 0});
        }

        int result = 0;
        for (int i = 0; i < k; i++) {
            int[] curr = minHeap.poll(); // the (i + 1)-th smallest overall
            result = curr[0];
            int row = curr[1], col = curr[2];
            // Advance within the same row, like moving a list pointer in Merge K Sorted Lists
            if (col + 1 < n) {
                minHeap.offer(new int[]{matrix[row][col + 1], row, col + 1});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] m = {{1,5,9},{10,11,13},{12,13,15}};
        System.out.println(kthSmallest(m, 8)); // 13
    }
}
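The Visual Representation at the top describes a slightly different heap strategy: seed the heap with only the top-left cell (1,0,0) and, after each pop, push the right neighbor and the neighbor below. Because a cell can be reached both from its left and from above, a visited array is needed so that no cell is pushed twice. The sketch below shows one way that variant could look. It is an assumption, not the approach implemented above, and `RightDownHeap` / `kthSmallestRightDown` are hypothetical names.

```java
import java.util.PriorityQueue;

public class RightDownHeap {
    public static int kthSmallestRightDown(int[][] matrix, int k) {
        int n = matrix.length;
        boolean[][] visited = new boolean[n][n]; // true once a cell has been pushed
        // Min-Heap of {value, row, col}, ordered by value
        PriorityQueue<int[]> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        minHeap.offer(new int[]{matrix[0][0], 0, 0});
        visited[0][0] = true;

        int result = matrix[0][0];
        for (int i = 0; i < k; i++) {
            int[] curr = minHeap.poll(); // the (i + 1)-th smallest overall
            result = curr[0];
            int row = curr[1], col = curr[2];
            // Right neighbor, if it exists and was not pushed yet
            if (col + 1 < n && !visited[row][col + 1]) {
                visited[row][col + 1] = true;
                minHeap.offer(new int[]{matrix[row][col + 1], row, col + 1});
            }
            // Neighbor below, if it exists and was not pushed yet
            if (row + 1 < n && !visited[row + 1][col]) {
                visited[row + 1][col] = true;
                minHeap.offer(new int[]{matrix[row + 1][col], row + 1, col});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] m = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallestRightDown(m, 8)); // 13
    }
}
```

Each pop pushes at most two cells, so the heap holds at most k + 1 entries and the loop runs in O(k log k) time. The visited array adds O(N^2) space, which a hash set of (row, col) pairs could replace if the matrix is large and k is small.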