후라이

[Gold-3] 11049번 | 동적계획법(DP) | 자바(Java) 본문

백준/Gold

[Gold-3] 11049번 | 동적계획법(DP) | 자바(Java)

힐안 2024. 11. 12. 18:50

https://www.acmicpc.net/problem/11049

 

해당 문제는 이전에 풀이했던 동적계획법 문제와 유사하되, 행렬 곱셈의 최솟값을 찾는 문제이다.

이 문제는 행렬 곱셈의 최소 연산 횟수를 구하기 위해 DP를 사용해야 한다.

머릿속에서 개념이 정확히 안 잡힌 탓인지 DP는 볼 때마다 헷갈리고 모르겠다...나만 그러나..나만 그러겠지

 

문제 분석

 

1. 주어진 행렬들은 순서를 바꿀 수 없기 때문에, 순서를 유지하되 곱하는 방식만 고려해야 한다.

2. DP 배열을 정의하여 최소 연산 횟수를 계산해 나간다.

 

DP 접근 방식

(1) DP 배열 정의 : dp[i][j]는 i번째 행렬부터 j번째 행렬까지 곱하는 최소 연산 횟수를 저장한다.

(2) 구간 나누기 : 이 문제의 핵심은 구간을 어떻게 나누느냐이다. 예를들어, (A * B * C)를 계산할 때는 k를 이용해 구간을 나누어 (A * (B * C))나 ((A * B) * C) 중 최소 연산 횟수를 선택한다.

 

[i, j]의 최소 연산 횟수를 찾으려면 k를 기준으로 나누어서 dp[i][k] + dp[k+1][j] + (i번째 행렬의 행 x k번째 행렬의 열 x j번째 행렬의 열)을 계산한다.

이때, 마지막 항 (i번째 행렬의 행 x k번째 행렬의 열 x j번째 행렬의 열)은 k 위치에서 두 구간을 합칠 때의 곱셈 연산을 나타낸다.

 

우선, 우리는 행렬 곱셈의 크기 정보를 저장해야 한다.

int N = Integer.parseInt(br.readLine());

        int [] dims = new int[N+1];

        for(int i=0;i<N;i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int r = Integer.parseInt(st.nextToken());
            int c = Integer.parseInt(st.nextToken());
            dims[i] = r;
            if(i==N-1) dims[N]=c;
        }
        int [][] dp = new int[N][N];

 

입력 예시에서 행렬 A : 5x3, B : 3x2, C : 2x6였다.

이 행렬 곱셈 크기 들을 저장할 때 {5, 2, 3, 6}의 배열로 저장하면

 

  • 첫 번째 행렬 A는 5 x 3으로, dims[0] x dims[1]로 표현
  • 두 번째 행렬 B는 3 x 2로, dims[1] x dims[2]로 표현
  • 세 번째 행렬 C는 2 x 6로, dims[2] x dims[3]으로 표현

이렇게 배열에서 인덱스로 표현할 수 있다. 그래서 i부터 j까지의 구간을 다루며 쉽게 곱셈 비용을 계산할 수 있다.

그럼 dp[i][j]에서도 결국 이 dims배열의 인덱스를 통해 i번째 부터 j번째 행렬까지의 곱을 쉽게 구할 수 있다.

 

이제 전체적인 코드를 짜기 전, 예시로 설명을 하면

위처럼 A,B,C의 행렬이 주어졌을 때, 이들을 곱할 순서를 최적화해야 한다.

 

1. dp[0][2]를 구할 때

i=0, j=2일 때, 우리는 A * B * C를 어떻게 곱할지 순서를 정해야 한다. 이때 가능한 분할 지점 k는

  • 첫 번째 분할 (k=0): A * (B * C)
    먼저 B*C를 계산한 후 A와 곱한 값이다.
    이때 dp[1][2]는 B*C의 곱셈 횟수이며, 그 후 A와 결과를 곱하면 된다.
  • 두 번째 분할 (k=1) : (A * B) * C
    먼저 A*B를 계산하고, 그 결과를 C와 곱한다. 이때 dp[0][1]은 A*B의 곱셈 횟수이며, 그 후 결과와 C를 곱하면 된다

2. dp[i][j] 계산

dp[i][j]는 i부터 j까지 곱하는 모든 경우를 고려하여, 가장 최소의 연산 횟수를 선택하는 방식이다.

예를 들어 dp[0][2]를 계산할 때

* k=0일 경우 dp[0][0] + dp[1][2] + dims[0] * dims[1] * dims[3]의 연산 횟수

* k=1일 경우 dp[0][1] + dp[2][2] + dims[0] * dims[2] * dims[3]의 연산 횟수

이 두 횟수들을 비교해서 더 작은 값을 선택하는 것이다.

 

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());

        // 행렬의 크기 저장
        int[] dims = new int[N + 1]; // N개의 행렬이므로 크기는 N+1
        for (int i = 0; i < N; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int r = Integer.parseInt(st.nextToken()); // 행렬의 행
            int c = Integer.parseInt(st.nextToken()); // 행렬의 열
            dims[i] = r;
            if (i == N - 1) dims[N] = c; // 마지막 열 추가
        }

        // DP 테이블 초기화
        int[][] dp = new int[N][N];
        
        // DP 점화식 계산
        for (int len = 1; len < N; len++) { // 부분 문제의 길이
            for (int i = 0; i + len < N; i++) {
                int j = i + len;
                dp[i][j] = Integer.MAX_VALUE;

                for (int k = i; k < j; k++) {
                    int cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                    dp[i][j] = Math.min(dp[i][j], cost);
                }
            }
        }

        // 최종 결과 출력
        System.out.println(dp[0][N - 1]);
    }
}

 

dp[i][j]의 의미

dp[i][j]는 i부터 j까지의 행렬을 곱할 때 필요한 최소 연산 횟수를 의미한다.

  • 예를 들어, dp[0][2]는 A * B * C를 곱할 때의 최소 연산 횟수를 의미하는데,
    dp[0][1]은 A * B를 곱할 때의 최소 연산 횟수 / dp[1][2]는 B * C를 곱할 때의 최소 연산 횟수이다.

1. 왜 dp[i][j]가 필요할까?

행렬 곱셈의 문제는 순서대로 연산을 진행할 때 연산 횟수가 최소가 되도록 순서를 결정해야 하는 문제이다. 이 문제를 해결하기 위해서는, 각 구간에 대해 어디서 분할할지를 결정해야 한다.

  • (A * B) * C: 먼저 A * B를 곱한 뒤, 그 결과와 C를 곱하는 방식.
  • A * (B * C): 먼저 B * C를 곱한 뒤, 그 결과와 A를 곱하는 방식.

이 두 가지 방법을 비교하여 더 적은 연산 횟수를 선택한다. 그래서 "구간마다 최소 연산 횟수를 저장" 하는 것이 필요하고, 이를 위해 2차원 배열을 사용한다. (두 지점 사이의 구간을 다루기 위해서)

 

2. dp 배열을 사용하는 방식

  • dp[i][j]는 i부터 j까지의 행렬을 곱할 때의 최소 연산 횟수를 구하는데 사용
  • 처음에는 dp[i][j]는 모두 Integer.MAX_VALUE로 초기화되어 있고, 나중에 최소 연산 횟수를 계산하면서 이 값을 갱신
  • k는 i부터 j까지의 구간에서, 중간 분할 지점을 결정하는 역할을 하며, 모든 가능한 k에 대해 최소 연산 횟수를 비교