Kth Smallest Element in a Sorted Matrix

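Given an $n \times n$ matrix in which every row and every column is sorted in ascending order, return the k-th smallest element. Duplicates count separately: this is the k-th element in sorted order, not the k-th distinct value.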

Medium

Examples

Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Approach 1

Level I: Min-Heap (K-Way Merge)

Intuition

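This is equivalent to merging $N$ sorted lists, one per row. Seed a min-heap with the first element of each row, then repeatedly pop the smallest element and push its right-hand neighbour from the same row. The $k$-th element popped is the answer.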

Time: $O(K \log N)$, Space: $O(N)$

Detailed Dry Run

matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Heap entries are (value, row, col).
Initial heap: [(1,0,0), (10,1,0), (12,2,0)]
Pop (1,0,0) -> push (5,0,1)
Pop (5,0,1) -> push (9,0,2)
...
8th pop = 13

java
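import java.util.PriorityQueue;

class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length;
        // Heap entries are {value, row, col}; Integer.compare avoids the overflow risk of a[0] - b[0].
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        // Seed the heap with the first element of every row.
        for (int i = 0; i < n; i++) pq.offer(new int[]{matrix[i][0], i, 0});
        // Discard the k - 1 smallest elements, advancing along the row of each popped entry.
        for (int i = 0; i < k - 1; i++) {
            int[] cur = pq.poll();
            if (cur[2] + 1 < n) pq.offer(new int[]{matrix[cur[1]][cur[2] + 1], cur[1], cur[2] + 1});
        }
        // The heap's minimum is now the k-th smallest element.
        return pq.poll()[0];
    }
}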
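
To check the dry run by hand, it can help to record the order in which values leave the heap. The following is a minimal sketch under that assumption: `HeapTrace` and `popOrder` are hypothetical names introduced only for illustration, and the method follows the same seed, pop and push steps as the solution above while collecting the first k popped values.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Illustrative helper (hypothetical name, not part of the solution):
// records the values popped from the heap, in order.
class HeapTrace {
    static List<Integer> popOrder(int[][] matrix, int k) {
        int n = matrix.length;
        // Heap entries are {value, row, col}, ordered by value.
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        for (int i = 0; i < n; i++) pq.offer(new int[]{matrix[i][0], i, 0});

        List<Integer> popped = new ArrayList<>();
        while (popped.size() < k && !pq.isEmpty()) {
            int[] cur = pq.poll();
            popped.add(cur[0]);
            int row = cur[1], nextCol = cur[2] + 1;
            // Advance along the row the popped element came from.
            if (nextCol < n) pq.offer(new int[]{matrix[row][nextCol], row, nextCol});
        }
        return popped;
    }
}
```

For the example matrix with k = 8, the recorded order is 1, 5, 9, 10, 11, 12, 13, 13, so the 8th pop is 13.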
Approach 2

Level III: Optimal (Binary Search on Answer)

Intuition

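The answer lies in the value range [matrix[0][0], matrix[n-1][n-1]]. For a candidate value mid, we can count the elements $\le$ mid in $O(N)$ by walking a staircase path from the bottom-left corner (or, symmetrically, from the top-right corner, as the code below does). The count never decreases as mid grows, so we binary search for the smallest value whose count is at least $k$. That value is always an element of the matrix, because the count only changes at values that appear in it.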
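
The bottom-left walk can be sketched as follows. This is an illustrative version only: `StaircaseCount` and `countLessOrEqual` are hypothetical names, and the matrix is assumed to be square with sorted rows and columns, as in the problem statement.

```java
// Illustrative helper (hypothetical name): counts elements <= target
// by walking from the bottom-left corner of a row- and column-sorted matrix.
class StaircaseCount {
    static int countLessOrEqual(int[][] matrix, int target) {
        int n = matrix.length;
        int row = n - 1, col = 0, count = 0;
        while (row >= 0 && col < n) {
            if (matrix[row][col] <= target) {
                // Columns are sorted, so everything above this cell is also <= target.
                count += row + 1;
                col++;
            } else {
                // This cell is too large; rows are sorted, so move up instead of right.
                row--;
            }
        }
        return count;
    }
}
```

Each step moves either one row up or one column right, so the walk takes at most $2N$ steps. On the example matrix, `countLessOrEqual(matrix, 13)` returns 8.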

Time: $O(N \cdot \log(Max - Min))$, Space: $O(1)$

Detailed Dry Run

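matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8

| Step | L | R | Mid | Count | Decision |
|------|---|---|-----|-------|----------|
| 1 | 1 | 15 | 8 | 2 | L = 9 |
| 2 | 9 | 15 | 12 | 6 | L = 13 |
| 3 | 13 | 15 | 14 | 8 | R = 14 |
| 4 | 13 | 14 | 13 | 8 | R = 13 |
| Exit | 13 | 13 | - | - | Return 13 |
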
java
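class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        int n = matrix.length, l = matrix[0][0], r = matrix[n - 1][n - 1];
        // Invariant: the answer always lies within [l, r].
        while (l < r) {
            int mid = l + (r - l) / 2;
            // Fewer than k elements are <= mid, so the answer is larger than mid.
            if (count(matrix, mid) < k) l = mid + 1;
            // At least k elements are <= mid, so mid itself may be the answer.
            else r = mid;
        }
        return l;
    }

    // Counts elements <= mid in O(n), walking from the top-right corner.
    int count(int[][] matrix, int mid) {
        int n = matrix.length, c = 0, j = n - 1;
        for (int i = 0; i < n; i++) {
            // Columns are sorted, so j never needs to move right as i increases.
            while (j >= 0 && matrix[i][j] > mid) j--;
            c += (j + 1); // row i contributes its first j + 1 elements
        }
        return c;
    }
}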
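
To reproduce a dry run for any input, one option is to record the state of each binary search iteration. The sketch below assumes the same search and counting logic as the solution above. `BinarySearchTrace` and `trace` are hypothetical names used only for illustration, and each recorded row holds `{l, r, mid, count}` for one iteration.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative helper (hypothetical name): records {l, r, mid, count}
// for every iteration of the binary search on the answer.
class BinarySearchTrace {
    static List<int[]> trace(int[][] matrix, int k) {
        int n = matrix.length, l = matrix[0][0], r = matrix[n - 1][n - 1];
        List<int[]> steps = new ArrayList<>();
        while (l < r) {
            int mid = l + (r - l) / 2;
            int c = count(matrix, mid);
            steps.add(new int[]{l, r, mid, c});
            if (c < k) l = mid + 1; else r = mid;
        }
        return steps;
    }

    // Same top-right staircase count as in the solution.
    static int count(int[][] matrix, int mid) {
        int n = matrix.length, c = 0, j = n - 1;
        for (int i = 0; i < n; i++) {
            while (j >= 0 && matrix[i][j] > mid) j--;
            c += j + 1;
        }
        return c;
    }
}
```

Printing each row with `java.util.Arrays.toString` for the example matrix and k = 8 lists the L, R, Mid and Count values of every iteration, which can be compared against a hand-written dry run.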