
Minimum Cost to Merge Stones

Master this topic from zero to advanced depth.


There are n piles of stones arranged in a row. The i-th pile has stones[i] stones.

A move consists of merging exactly k consecutive piles into one pile, and the cost of this move is equal to the total number of stones in these k piles.

Return the minimum cost to merge all piles of stones into one pile. If it is impossible, return -1.

Visual Representation

stones = [3, 2, 4, 1], k = 2
[3, 2, 4, 1] -> [5, 4, 1] (cost 5)
[5, 4, 1] -> [5, 5] (cost 5)
[5, 5] -> [10] (cost 10)
Total cost: 20 (minimum)

Difficulty: Hard
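To make the cost rule concrete, here is a minimal sketch that replays a fixed sequence of moves and adds up their costs. The method `replay` and the convention of describing each move by the index of its leftmost pile are illustrative assumptions, not part of the problem statement.

```java
import java.util.ArrayList;
import java.util.List;

public class Main {
    // Applies each move in order. A move is given as the index of the leftmost
    // of k consecutive piles (a hypothetical encoding used only in this sketch).
    // Returns the total cost: the sum of the stones merged by every move.
    static int replay(int[] stones, int k, int[] moves) {
        List<Integer> piles = new ArrayList<>();
        for (int s : stones) piles.add(s);
        int total = 0;
        for (int start : moves) {
            if (start < 0 || start + k > piles.size())
                throw new IllegalArgumentException("no k consecutive piles at " + start);
            int merged = 0;
            // remove(int) removes by index; later piles shift left each time.
            for (int t = 0; t < k; t++) merged += piles.remove(start);
            piles.add(start, merged); // the new pile takes the place of the k old ones
            total += merged;          // cost of this move = stones in the k piles
        }
        if (piles.size() != 1) throw new IllegalStateException("piles left: " + piles);
        return total;
    }

    public static void main(String[] args) {
        // The sequence from the visual: merge (3, 2), then (4, 1), then (5, 5).
        System.out.println(replay(new int[]{3, 2, 4, 1}, 2, new int[]{0, 1, 0})); // 20
        // Another order: merge (2, 4) first, which turns out more expensive.
        System.out.println(replay(new int[]{3, 2, 4, 1}, 2, new int[]{1, 0, 0})); // 25
    }
}
```

Replaying different orders this way shows why the problem needs a search: the same stones can cost 20 or 25 depending on which piles are merged first.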

Examples

Input: stones = [3, 2, 4, 1], k = 2
Output: 20
Input: stones = [3, 2, 4, 1], k = 3
Output: -1

Explanation for the second example: each move turns 3 piles into 1, so the 4 piles become 2 after one merge. No further move is possible, because a move needs exactly k = 3 piles.
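The -1 case follows from counting piles: every move replaces k piles with one, so the pile count drops by exactly k - 1, and n piles can reach 1 only if n - 1 is a multiple of k - 1. A minimal sketch of that check (the method name `canMerge` is hypothetical; it assumes k >= 2):

```java
public class Main {
    // After t moves there are n - t * (k - 1) piles, so reaching exactly
    // 1 pile requires (n - 1) % (k - 1) == 0. Assumes k >= 2.
    static boolean canMerge(int n, int k) {
        return (n - 1) % (k - 1) == 0;
    }

    public static void main(String[] args) {
        System.out.println(canMerge(4, 2)); // true:  4 -> 3 -> 2 -> 1
        System.out.println(canMerge(4, 3)); // false: 4 -> 2, stuck
        System.out.println(canMerge(5, 3)); // true:  5 -> 3 -> 1
    }
}
```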

Approach 1

Level I: Brute Force (Recursion)

Intuition

Try every split point of the range [i, j] when reducing it to m piles. This is the classic exhaustive search for partition problems.

Thought Process

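  1. solve(i, j, m) returns the minimum cost to merge stones[i..j] into exactly m piles.
  2. Transitions:
    • For m > 1, split [i, j] into [i, mid] (merged into 1 pile) and [mid + 1, j] (merged into m - 1 piles).
    • solve(i, j, m) = min over mid of solve(i, mid, 1) + solve(mid + 1, j, m - 1).
    • For m == 1, first reduce [i, j] to K piles, then merge them in one move: solve(i, j, 1) = solve(i, j, K) + sum(stones[i..j]).
  3. Base case: solve(i, i, 1) = 0; solve(i, i, m) for m > 1 is impossible.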
Time: Exponential, Space: O(N)

Detailed Dry Run

stones = [3, 2, 4], K = 3

  1. solve(0, 2, 1) = solve(0, 2, 3) + (3 + 2 + 4).
  2. solve(0, 2, 3): the only valid split is [0, 0] (1 pile) and [1, 2] (2 piles), so it equals solve(0, 0, 1) + solve(1, 2, 2) = 0 + 0 = 0.
  3. Result: 0 + 9 = 9.
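The dry run can be reproduced with a small tracing variant of the recursion. This is a sketch: the `trace` list, the `INF` sentinel, and the renamed split variable `mid` are illustrative additions, not part of the original solution.

```java
import java.util.ArrayList;
import java.util.List;

public class Main {
    // Records every call as "solve(i,j,m)" so the dry run can be checked by eye.
    static final List<String> trace = new ArrayList<>();
    static final int INF = 1_000_000_000; // sentinel for "impossible"

    // Same recurrence as the brute force: min cost to merge s[i..j] into m piles.
    static int solve(int[] s, int K, int i, int j, int m) {
        trace.add("solve(" + i + "," + j + "," + m + ")");
        if (i == j) return m == 1 ? 0 : INF;
        if (m == 1) {
            int sub = solve(s, K, i, j, K); // reduce to K piles first
            if (sub >= INF) return INF;
            int sum = 0;
            for (int t = i; t <= j; t++) sum += s[t];
            return sub + sum;               // the final merge costs sum(s[i..j])
        }
        int res = INF;
        for (int mid = i; mid < j; mid += K - 1) {
            int left = solve(s, K, i, mid, 1), right = solve(s, K, mid + 1, j, m - 1);
            if (left < INF && right < INF) res = Math.min(res, left + right);
        }
        return res;
    }

    public static void main(String[] args) {
        int cost = solve(new int[]{3, 2, 4}, 3, 0, 2, 1);
        trace.forEach(System.out::println); // six calls, in depth-first order
        System.out.println("cost = " + cost); // cost = 9
    }
}
```

The printed calls are solve(0,2,1), solve(0,2,3), solve(0,0,1), solve(1,2,2), solve(1,1,1), solve(2,2,1), matching the steps above.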
java
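public class Main {
    // Sentinel for "impossible"; the sum of two finite costs below it still fits in an int.
    static final int INF = 1_000_000_000;
    public static int mergeStones(int[] stones, int K) {
        int n = stones.length;
        // Each move removes K - 1 piles, so n piles reach 1 only if (n - 1) % (K - 1) == 0.
        if ((n - 1) % (K - 1) != 0) return -1;
        return solve(stones, K, 0, n - 1, 1);
    }

    // Minimum cost to merge stones[i..j] into exactly m piles (INF if impossible).
    public static int solve(int[] stones, int K, int i, int j, int m) {
        if (i == j) return m == 1 ? 0 : INF;
        if (m == 1) {
            // Reduce [i, j] to K piles, then pay sum(i..j) for the final merge.
            int sub = solve(stones, K, i, j, K);
            return sub >= INF ? INF : sub + sum(stones, i, j);
        }
        int res = INF;
        // The left part [i, mid] must become 1 pile, so mid advances by K - 1.
        for (int mid = i; mid < j; mid += K - 1) {
            int left = solve(stones, K, i, mid, 1);
            int right = solve(stones, K, mid + 1, j, m - 1);
            if (left < INF && right < INF) res = Math.min(res, left + right);
        }
        return res;
    }

    private static int sum(int[] s, int i, int j) {
        int total = 0; for (int t = i; t <= j; t++) total += s[t]; return total;
    }

    public static void main(String[] args) {
        System.out.println(mergeStones(new int[]{3, 2, 4, 1}, 2)); // 20
        System.out.println(mergeStones(new int[]{3, 2, 4, 1}, 3)); // -1
    }
}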
Approach 2

Level II: Memoization (Top-Down 3D)

Intuition

Cache the result of solve(i, j, m), the minimum cost to merge stones[i..j] into m piles, so that each (interval, pile count) combination is computed only once.

Visual

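solve(0,3,1) [merge all 4 piles into 1, K = 2]
  -> solve(0,3,2) + sum(stones[0..3]) [reduce to K piles first]
  -> solve(0,3,2) = min over splits, e.g. solve(0,1,1) + solve(2,3,1)
Each sub-call is cached after its first computation.
Time: O(N^3 / K), Space: O(N^2 * K)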
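To make the benefit of caching measurable, here is a rough sketch that counts how many times the recurrence body runs with and without the memo table on the same input. The counter `calls`, the helper `countCalls`, and the chosen input are hypothetical instrumentation, not part of the solution. With K = 2, the uncached count in this instrumentation triples with every extra pile, while the cached count is bounded by the number of (i, j, m) states.

```java
import java.util.Arrays;

public class Main {
    static final int INF = 1_000_000_000;
    static long calls;      // times the recurrence body executes (cache hits excluded)
    static int[][][] memo;  // null means "no caching", i.e. plain brute force
    static int[] pre;       // prefix sums: sum(i..j) = pre[j + 1] - pre[i]

    static int solve(int K, int i, int j, int m) {
        if (i == j) return m == 1 ? 0 : INF;
        if (memo != null && memo[i][j][m] != -1) return memo[i][j][m]; // cache hit
        calls++;
        int res;
        if (m == 1) {
            int sub = solve(K, i, j, K);
            res = sub >= INF ? INF : sub + pre[j + 1] - pre[i];
        } else {
            res = INF;
            for (int mid = i; mid < j; mid += K - 1) {
                int left = solve(K, i, mid, 1), right = solve(K, mid + 1, j, m - 1);
                if (left < INF && right < INF) res = Math.min(res, left + right);
            }
        }
        if (memo != null) memo[i][j][m] = res;
        return res;
    }

    static long countCalls(int[] stones, int K, boolean useMemo) {
        int n = stones.length;
        pre = new int[n + 1];
        for (int i = 0; i < n; i++) pre[i + 1] = pre[i] + stones[i];
        memo = null;
        if (useMemo) {
            memo = new int[n][n][K + 1];
            for (int[][] a : memo) for (int[] b : a) Arrays.fill(b, -1);
        }
        calls = 0;
        solve(K, 0, n - 1, 1);
        return calls;
    }

    public static void main(String[] args) {
        int[] stones = {3, 5, 1, 2, 6, 4, 2, 8, 1, 7, 3}; // n = 11
        System.out.println("brute force: " + countCalls(stones, 2, false)); // 39366
        System.out.println("memoized:    " + countCalls(stones, 2, true));  // 110
    }
}
```

With memoization, each of the 55 intervals is solved once for m = 1 and once for m = 2, which gives the count of 110.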

Detailed Dry Run

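stones = [3, 2, 4], K = 3. solve(0, 2, 1) first needs solve(0, 2, 3), which splits into solve(0, 0, 1) + solve(1, 2, 2) = 0 + 0 = 0 (three piles are already three piles, so no merge is needed). Then solve(0, 2, 1) = 0 + 9 = 9, and every computed state is stored in memo.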

java
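import java.util.Arrays;

public class Main {
    static final int INF = 1_000_000_000; // sentinel for "impossible"

    public static int mergeStones(int[] stones, int K) {
        int n = stones.length;
        if ((n - 1) % (K - 1) != 0) return -1;  // pile count can never reach 1
        int[] pre = new int[n + 1];               // pre[j + 1] - pre[i] = sum(stones[i..j])
        for (int i = 0; i < n; i++) pre[i + 1] = pre[i] + stones[i];
        int[][][] memo = new int[n][n][K + 1];    // memo[i][j][m], -1 = not computed yet
        for (int[][] a : memo) for (int[] b : a) Arrays.fill(b, -1);
        return solve(K, pre, memo, 0, n - 1, 1);
    }

    // Minimum cost to merge stones[i..j] into exactly m piles.
    static int solve(int K, int[] pre, int[][][] memo, int i, int j, int m) {
        if (i == j) return m == 1 ? 0 : INF;
        if (memo[i][j][m] != -1) return memo[i][j][m];
        int res;
        if (m == 1) {
            // Reduce to K piles, then pay for the final merge of the whole range.
            int sub = solve(K, pre, memo, i, j, K);
            res = sub >= INF ? INF : sub + pre[j + 1] - pre[i];
        } else {
            res = INF;
            // Left part [i, mid] becomes 1 pile; right part becomes m - 1 piles.
            for (int mid = i; mid < j; mid += K - 1) {
                int left = solve(K, pre, memo, i, mid, 1);
                int right = solve(K, pre, memo, mid + 1, j, m - 1);
                if (left < INF && right < INF) res = Math.min(res, left + right);
            }
        }
        return memo[i][j][m] = res;
    }

    public static void main(String[] args) {
        System.out.println(mergeStones(new int[]{3, 2, 4, 1}, 2)); // 20
        System.out.println(mergeStones(new int[]{3, 2, 4, 1}, 3)); // -1
    }
}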
Approach 3

Level III: Dynamic Programming (Interval/3D)

Intuition

This is an interval DP problem with an extra dimension: we want the minimum cost to merge the stones in range [i, j] into m piles.

Thought Process

  1. dp[i][j][m] is the minimum cost to merge stones[i..j] into exactly m piles.
  2. Transitions:
    • For m > 1, split [i, j] into [i, mid] (merged into 1 pile) and [mid + 1, j] (merged into m - 1 piles).
    • dp[i][j][m] = min(dp[i][mid][1] + dp[mid + 1][j][m - 1]) for mid = i, i + (K - 1), i + 2(K - 1), ... while mid < j; any other mid leaves a left part that cannot become one pile.
    • Base case: dp[i][i][1] = 0 (a single pile is already 1 pile).
  3. When m == 1, the cost is dp[i][j][K] + sum(stones[i..j]): reduce the range to K piles, then merge them in one final move.
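Time: O(N^3 * K) (O(N^3) when split points step by K - 1), Space: O(N^2 * K) for the 3D table. The Java code below compresses the state to 2D: O(N^3 / K) time and O(N^2) space.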
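The thought process describes a 3D table, while the Java code in this section compresses it to 2D. For comparison, here is a minimal bottom-up sketch of the 3D version exactly as described. The method name `mergeStones3D` and the `INF` sentinel are illustrative choices, not from the original.

```java
import java.util.Arrays;

public class Main {
    static final int INF = 1_000_000_000; // sentinel for "impossible"

    // dp[i][j][m] = min cost to merge stones[i..j] into exactly m piles.
    static int mergeStones3D(int[] stones, int K) {
        int n = stones.length;
        if ((n - 1) % (K - 1) != 0) return -1;
        int[] pre = new int[n + 1];
        for (int i = 0; i < n; i++) pre[i + 1] = pre[i] + stones[i];

        int[][][] dp = new int[n][n][K + 1];
        for (int[][] a : dp) for (int[] b : a) Arrays.fill(b, INF);
        for (int i = 0; i < n; i++) dp[i][i][1] = 0; // base case

        for (int len = 2; len <= n; len++) {
            for (int i = 0; i + len - 1 < n; i++) {
                int j = i + len - 1;
                // m >= 2: [i, mid] becomes 1 pile, [mid + 1, j] becomes m - 1 piles.
                // Both sub-ranges are shorter than len, so they are already filled in.
                for (int m = 2; m <= K; m++) {
                    for (int mid = i; mid < j; mid += K - 1) {
                        int left = dp[i][mid][1], right = dp[mid + 1][j][m - 1];
                        if (left < INF && right < INF)
                            dp[i][j][m] = Math.min(dp[i][j][m], left + right);
                    }
                }
                // m == 1: reduce to K piles, then pay for the final merge.
                if (dp[i][j][K] < INF)
                    dp[i][j][1] = dp[i][j][K] + pre[j + 1] - pre[i];
            }
        }
        return dp[0][n - 1][1];
    }

    public static void main(String[] args) {
        System.out.println(mergeStones3D(new int[]{3, 2, 4, 1}, 2));    // 20
        System.out.println(mergeStones3D(new int[]{3, 2, 4, 1}, 3));    // -1
        System.out.println(mergeStones3D(new int[]{3, 5, 1, 2, 6}, 3)); // 25
    }
}
```

The 3D table makes the pile count explicit at the cost of O(N^2 * K) memory, which is what the 2D version avoids.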

Detailed Dry Run

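Like matrix chain multiplication, this is an interval DP over split points. The pile-count dimension matters because a move must merge EXACTLY K piles. The Java code compresses it away: dp[i][j] holds the minimum cost to merge stones[i..j] into as few piles as possible, namely (len - 1) % (K - 1) + 1 piles, and the range sum is added only when (len - 1) % (K - 1) == 0, that is, when the range can collapse into a single pile. Ranges shorter than K cannot be merged, so they stay at cost 0.

stones = [3, 2, 4, 1], K = 2:

  1. len = 2: dp[0][1] = 5, dp[1][2] = 6, dp[2][3] = 5.
  2. len = 3: dp[0][2] = min(0 + 6, 5 + 0) + 9 = 14; dp[1][3] = min(0 + 5, 6 + 0) + 7 = 12.
  3. len = 4: dp[0][3] = min(0 + 12, 5 + 5, 14 + 0) + 10 = 20.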

java
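public class Main {
    public static int mergeStones(int[] stones, int K) {
        int n = stones.length;
        // Each move removes K - 1 piles; otherwise one final pile is unreachable.
        if ((n - 1) % (K - 1) != 0) return -1;

        int[] prefixSum = new int[n + 1];
        for (int i = 0; i < n; i++) prefixSum[i + 1] = prefixSum[i] + stones[i];

        // dp[i][j] = min cost to merge stones[i..j] into as few piles as possible,
        // i.e. (len - 1) % (K - 1) + 1 piles. Ranges shorter than K stay at cost 0.
        int[][] dp = new int[n][n];
        for (int len = K; len <= n; len++) {
            for (int i = 0; i <= n - len; i++) {
                int j = i + len - 1;
                dp[i][j] = Integer.MAX_VALUE;
                // Left part [i, mid] collapses to one pile, so mid advances by K - 1.
                for (int mid = i; mid < j; mid += K - 1) {
                    dp[i][j] = Math.min(dp[i][j], dp[i][mid] + dp[mid + 1][j]);
                }
                // If the range can become a single pile, add the cost of that final merge.
                if ((len - 1) % (K - 1) == 0) {
                    dp[i][j] += prefixSum[j + 1] - prefixSum[i];
                }
            }
        }
        return dp[0][n - 1];
    }

    public static void main(String[] args) {
        System.out.println(mergeStones(new int[]{3, 2, 4, 1}, 2)); // 20
        System.out.println(mergeStones(new int[]{3, 2, 4, 1}, 3)); // -1
    }
}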