[java,백준2749]피보나치 수 3 - 행렬 제곱과 DP를 활용한 최적화

행렬의 거듭제곱 성질과 DP를 활용하여 피보나치 수열의 N번째 수를 빠르게 구하는 문제입니다. 피사노 주기를 사용한 방법은 아닙니다.

피보나치 수 3(백준2749)


아래 피보나치 수를 행렬로 바꿔 나타내보면 공식을 도출할 수 있다.(자세한 도출 과정은 생략. 필요시 구글링)

Image

n=1일때 대입해보면 왼쪽처럼 {0,1},{1,1} 이 나옴을 알 수 있음.
만약 n=5면? {0,1},{1,1} 을 5번 행렬곱 하면 구해짐.
이 행렬식을 이용해서 피보나치 수 구하게 되면 정답 도출은 result[0][1] 이런식으로 fn 요소만 출력!



풀이

구현 주의사항
💡 다음 사항들을 반드시 체크해야 합니다:

  • long 자료형 사용 필수
  • 행렬 연산 적용
  • 메모이제이션으로 중복 계산 방지

이 문제는 행렬의 거듭제곱을 분할정복으로 계산하고, DP로 중복 계산을 방지하는 방식으로 구현했습니다.

핵심 구현 사항

  1. 2x2 행렬의 곱셈 함수 구현

  2. 분할정복으로 거듭제곱 최적화

    예로 200제곱x200제곱=400제곱 이라는 단순한 수학계산이지만 이 아이디어로 복잡도를 감소!

  3. HashMap을 활용한 메모이제이션 => top-bottom DP


행렬 곱셈

static Long[][] calMatrix(Long[][] m1, Long[][] m2, long n) {
    Long[][] result = new Long[2][2];
    result[0][0] = (m1[0][0]*m2[0][0] + m1[0][1]*m2[1][0]) % 1000000;
    result[0][1] = (m1[0][0]*m2[0][1] + m1[0][1]*m2[1][1]) % 1000000;
    result[1][0] = result[0][1];  // 피보나치 행렬의 특성
    result[1][1] = (m1[1][0]*m2[0][1] + m1[1][1]*m2[1][1]) % 1000000;
    map.put(n, result);
    return result;
}


분할정복 피보나치

  • 홀수의 이해를 돕자면, n=5일 경우 n/2=2 가 나옴.
    우린 제곱수를 합하면 5가 나와야하니까 3이 필요!
    그래서 n/2 에 +1 을 추가했을 뿐임.
  • 제곱수 합한다는 의미는 a^200 x a^200 = a^400 를 구할때 곱하기에서 지수는 더해진다는 수학적 특징을 말함.(지수:200+200=400)
static Long[][] fib(long n) {
    if(n == 1) return matrix;
    if(map.containsKey(n)) return map.get(n);
    
    return n % 2 == 0 ? 
        calMatrix(fib(n/2), fib(n/2), n) : 
        calMatrix(fib(n/2), fib(n/2 + 1), n);
}



전체 코드

public class 피보나치수3_2749 {
  static Long[][] matrix = // {0L, 1L}, {1L, 1L}가 할당되게 초기화 ㄱ
  static HashMap<Long, Long[][]> map = new HashMap<>(); //메모이제이션

  //행렬 연산 함수
  static Long[][] calMatrix(Long[][] m1, Long[][] m2, long n) {
    Long[][] result = new Long[2][2];
    result[0][0] = (m1[0][0]*m2[0][0]+m1[0][1]*m2[1][0])%1000000;
    result[0][1] = (m1[0][0]*m2[0][1]+m1[0][1]*m2[1][1])%1000000;
    result[1][0] = result[0][1]; //=> 피보나치 수 행렬 특징임
    result[1][1] = (m1[1][0]*m2[0][1]+m1[1][1]*m2[1][1])%1000000;
    map.put(n, result);
    return result;
  }

  static Long[][] fib(long n) {
    //base condition
    if(n==1) {
      return matrix;
    }
    if (map.containsKey(n)) {
      return map.get(n);
    }
    //recursion
    if (n % 2 == 0) { //짝수
      return calMatrix(fib(n / 2), fib(n / 2),n);
    } else { //홀수
      return calMatrix(fib(n / 2), fib(n / 2 + 1),n);
    }
  }

  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    long n = Long.parseLong(br.readLine());
    Long[][] result = fib(n);
    System.out.println(result[0][1]);
  }
}

주관적인 생각의 풀이이므로 틀린 부분 지적은 언제나 환영합니다⚡

댓글남기기