[개념] 행렬곱셈(쉬트라쎈)

행렬 곱셈의 효율적인 계산 방법인 쉬트라쎈 알고리즘에 대해 설명합니다. 이 글에서는 일반적인 행렬 곱셈 알고리즘(O(n³))과 쉬트라쎈 알고리즘의 개념, 동작 방식, 그리고 두 알고리즘의 성능 비교를 다룹니다. 쉬트라쎈 알고리즘은 행렬을 4개로 분할하여 7번의 곱셈과 18번의 덧셈/뺄셈으로 계산하며, 큰 행렬에서 더 효율적입니다.


단순한 행렬 곱셈 알고리즘과 좀 다른 방식인 쉬트라쎈 알고리즘을 소개한다.



행렬 곱셈(Matrix Multiplication)

C[i][j] = C[i][j] + A[i][k]*B[k][j] : n*n행렬인 A,B의 곱의 핵심 코드

  • 곱셈 방식 : ---*||| ___*||| 형식으로 곱

  • 복잡도 : n^3

    • for문이 3중이기 때문
    • 아래 코드를 확인

image-20221228193457336

  • 코드를 따라 해석해보면, 실제로 사람이 하는 행렬곱 방식과 동일하게 흘러감



쉬트라쎈의 알고리즘

여기선 개념과 동작방식만 참고하고 실제 구현코드는 구글링을 통해서 참고할 것

image-20221228194422555

  • 4개로 행렬을 분할해서 M공식들을 구해야한다는 점을 참고
  • 저렇게 M으로 구한 공식을 C행렬 공식에 그대로 대입하면 행렬곱의 해를 구할 수 있음


수도코드

void strassen(~~) {
if 분기점 이하
	기존 무작정 행렬 곱 사용(=기본 행렬 곱 알고리즘)
else 분기점 초과
	A, B행렬 4개 부분행렬로 분할(재귀 아님)
	쉬트라쎈 방법을 재귀적 구현 C = A*B 계산 // 예: strassen(n/2, A11+A22, B11+B22, M1);
	// 참고로 A11+A22란건 두 행렬 더했다는 것(M공식 위함)
	// n/2로 재귀되므로 분기점 이하에서 곱해질거임. 이를 M1으로 반환
	// 이 형식으로 M1~M7전부 가져오기
	// 구한 M1~M7로 진짜 C를 공식 이용해서 구함
	// 마지막에 구한 C11, C12...를 행렬 병합(합치는 함수 만들어 사용)
}
  • 임계값은 단순한 알고리즘보다 쉬트라쎈 알고리즘을 사용하면 더 좋을 것이라고 예상되는 지점을 의미
    • 이 임게값을 설정하는 이유는 아래 성능비교를 보면 알 수 있다.
    • 참고로 임계값은 컴퓨터마다 성능이 다르니까 조금씩 다 다르고, 보통 64인가 32이라고 한다.



성능 비교

2x2 행렬 A, B를 곱했을 때??

  • 기존 행렬 곱셈 알고리즘 : 8번 곱, 4번 덧셈
  • 쉬트라쎈 알고리즘 : 7번의 곱, 18번의 덧셈/뺄셈 (M1~M7과 C공식 통틀어서 횟수 구한 것)
  • 이를 보면 연산수가 행렬의 크기가 작을땐 기존 행렬 곱셈 알고리즘이 더 우수
  • 그러나, 행렬이 커지면?? 쉬트라쎈이 우수(분할해서 2*2로 만들어서 연산하면 되니까)
    • 기존 행렬의 경우 행렬이 커지면 곱셈 수가 많이 커져버림

댓글남기기