길이가 N이고 영문 알파벳 중 처음 K개의 알파벳을 사용해서 만든 암호 중 안전한 암호의 개수를 구한다. 여기서 ‘안전한 암호’란 “ABCBC” 또는 “ABABC” 패턴이 없는 암호를 의미한다.
처음에는 여집합을 이용해서 구하는 방법을 고민했다. 전체 암호의 개수에서 안전하지 않은 암호의 개수를 구하면 안전한 암호의 개수를 알 수 있기 때문이다. 이처럼 여집합 아이디어를 이용해서 문제를 해결하는 경우가 종종 있다. 그러나 안전하지 않은 암호의 개수를 구하는 것이 쉽지 않음을 느꼈다. 암호에 같은 패턴이 여러 번 나오거나 앞뒤로 서로 겹쳐서 나오는 경우 (예: “CABABCBC”) 개수를 세는 것이 곤란하다. 그래서 이 문제는 직접 안전한 암호 개수를 구하는 것이 바람직하다고 판단했다.
막연하게 Dynamic Programming 기법을 사용하여 메모이제이션 할 배열을 설정하면 다음과 같을 것이다.
dp[N][L1][L2][L3][L4][L5]: 길이가 N인 안전한 암호 중 맨 마지막 다섯 글자가 “L1L2L3L4L5”인 것의 개수
하지만 한 자리에 들어올 수 있는 알파벳은 26가지이고, 다섯 자리에 관해 들어올 수 있는 알파벳 경우의 수를 모두 고려해야 한다. 따라서 106 * 265 원소 개수만큼의 메모리가 필요한데, 이는 당연히 메모리 제한을 초과한다.
그런데 사실 맨 마지막 다섯 자리에는 패턴을 구성하는 문자인 A, B, C가 아닌 이상 정확히 어떤 알파벳이 오는지 저장해야 할 필요가 없다. 따라서 다섯 글자의 각 자리에 네 가지 경우의 수만 온다고 볼 수도 있다. 그러나 이 방법도 106 * 45 원소 개수 만큼의 메모리를 필요로 하여 메모리 제한 초과를 불러 온다.
이제까지 고안한 방법들이 계속 메모리 부족 현상이 발생하는 이유는 맨 마지막 다섯자리에 각각 어떠한 문자가 오는지 그 모든 경우의 수를 고려하려는 시도 때문이다. 사실, 길이 1부터 N까지 차례대로 살펴보면서 안전한 암호가 될 수 있는 경우의 수를 구하는데, 문제에서 주어진 패턴이 마지막 부분에 포함이 되지 않도록 피해주기만 하면 된다. (여기서 왜 마지막 부분만 봐 주면 되는지 의문이 든다면, Dynamic Programming을 사용하여 길이가 1부터 N까지 차례대로 구해 나간다는 점을 이해해야 한다. 길이 N인 안전한 암호의 개수를 구해야 하는 시점에서 맨 마지막 글자를 제외하고 앞의 길이 N - 1 구간도 패턴을 포함하지 않은 안전한 암호임이 보장된다.)
dp[N][M]: 길이가 N이고 암호의 맨 끝 부분의 상태가 M인 안전한 암호의 개수
어떠한 경우에 어떤 문자가 들어올 때 문제에서 주어진 패턴이 포함되는지를 바탕으로 암호의 맨 끝 부분에 따른 전체 암호의 상태를 새롭게 정의했다. 이를 도식화하면 다음과 같다.
상태 1은 암호의 맨 끝 부분이 “A”인 상태
상태 2는 암호의 맨 끝 부분이 “AB”인 상태
상태 3은 암호의 맨 끝 부분이 “ABA”인 상태
상태 4는 암호의 맨 끝 부분이 “ABCB”인 상태
상태 5는 암호의 맨 끝 부분이 “ABA”인 상태
상태 6은 암호의 맨 끝 부분이 “ABAB”인 상태
상태 0은 상태 1 ~ 7 중 어떠한 상태에도 해당되지 않는 경우
이를 바탕으로 어떤 상태일 때 어떠한 한 개의 알파벳이 올 때 다음 상태가 어떤 상태가 되는지 관계를 파악했고, 아래는 이를 정리한 표이다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int N_MAX = (int)1e6;
const long long MOD = (long long)(1e9 + 9);
int n, k;
long long dp[N_MAX + 1][7];
int main(){
scanf("%d %d", &n, &k);
dp[0][0] = 1;
for (int i = 1; i <= n; i++){
dp[i][0] = dp[i - 1][0] * (k - 1);
for (int j = 1; j < 7; j++){
dp[i][0] += dp[i - 1][j] * (k - 2);
}
dp[i][1] = dp[i - 1][0] + dp[i - 1][1] + dp[i - 1][3] + dp[i - 1][4] + dp[i - 1][5];
dp[i][2] = dp[i - 1][1];
dp[i][3] = dp[i - 1][2];
dp[i][4] = dp[i - 1][3];
dp[i][5] = dp[i - 1][2] + dp[i - 1][6];
dp[i][6] = dp[i - 1][5];
for (int j = 0; j < 7; j++){
dp[i][j] %= MOD;
}
}
long long ans = 0;
for (int i = 0; i < 7; i++){
ans += dp[n][i];
}
ans %= MOD;
printf("%lld\n", ans);
return 0;
}
여기서 슬라이딩 윈도우(Sliding Window) 기법을 사용하면 사용하는 메모리를 더 절약할 수 있다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int N_MAX = (int)1e6;
const long long MOD = (long long)(1e9 + 9);
int n, k;
long long dp[2][7];
int main(){
scanf("%d %d", &n, &k);
dp[0][0] = 1;
for (int i = 1; i <= n; i++){
dp[1][0] = dp[0][0] * (k - 1);
for (int j = 1; j < 7; j++){
dp[1][0] += dp[0][j] * (k - 2);
}
dp[1][1] = dp[0][0] + dp[0][1] + dp[0][3] + dp[0][4] + dp[0][5];
dp[1][2] = dp[0][1];
dp[1][3] = dp[0][2];
dp[1][4] = dp[0][3];
dp[1][5] = dp[0][2] + dp[0][6];
dp[1][6] = dp[0][5];
for (int j = 0; j < 7; j++){
dp[1][j] %= MOD;
dp[0][j] = dp[1][j];
dp[1][j] = 0;
}
}
long long ans = 0;
for (int i = 0; i < 7; i++){
ans += dp[0][i];
}
ans %= MOD;
printf("%lld\n", ans);
return 0;
}