N명의 직원이 사수와 부사수 관계가 배정된다. 그런데 모든 직원에게는 사수가 한 명씩만 배정된다고 했으므로 N명의 직원을 정점으로 고려할 때 그래프는 트리 형태를 지님을 알 수 있다. 사장 승범이인 root 정점을 제외한 직원 정점은 오직 하나의 부모 정점만 갖게 되고, 이는 임의의 서로 다른 두 정점 사이의 경로는 하나밖에 없음이 보장되기 때문이다.
사수와 부사수 관계에 있는 두 직원을 각각 서로 멘토와 멘티 관계로 맺게 하는데, 한 직원은 최대 1개의 멘토링 관계에만 속할 수 있다. 여기서 이 문제는 트리에서의 Dynamic Programming을 사용하는 트리 색칠하기 문제와 닮았다는 사실을 알 수 있다.
목표는 이 그래프에서 가능한 멘토 멘티 관계의 모든 시너지 합을 구하는 것이다. 중요한 점은 크게 두 가지로 말할 수 있다.
1. 어떤 두 노드를 멘토 멘티 관계로 정하느냐에 따라서 그 하위 자식들의 관계도 달라져서 시너지의 최댓값이 달라질 수 있다.
2. 같은 부모 노드를 갖는 노드들의 시너지 최댓값은 서로 독립적이다.
트리에서의 dynamic programming은 대개 완전 탐색을 통해 구한다. DFS(Depth First Search)를 진행하면서 방문하는 노드에 관하여 해당 노드를 root 노드로 하는 서브 트리에서의 시너지 최댓값을 memoization 할 수 있다. 그러면 시간 복잡도는 방문하는 노드의 개수인 O(N)이 된다.
memoization을 위한 배열의 정의는 다음과 같이 한다.
dp[i][j]: i번째 노드를 root 노드로 하는 서브 트리에서 i번째 노드를 멘토로 정하느냐 그렇지 않느냐(j)에 따른 시너지 합의 최댓값
i번째 노드를 멘토로 정하는 경우 ↔ j == 1
i번째 노드를 멘토로 정하지 않는 경우 ↔ j == 0
DFS 과정에서의 현재 방문 중인 노드를 노드 A, 노드 A를 root 노드로 하는 서브 트리를 트리 A라고 하자.
1) 노드 A가 멘토로 선택되지 않은 경우 (j == 0)
이는 노드 A가 멘티이거나 아무 역할로 정해지지 않은 경우를 의미한다. 노드 A를 기준으로 노드 A의 각 자식들은 서로 독립적이므로 각 자식들의 시너지 합의 최댓값이 각각 크면 클수록 이득이다. (트리 A의 시너지 합의 최댓값이 보장된다.)
2) 노드 A가 멘토로 선택되는 경우 (j == 1)
노드 A의 자식 중에서 하나의 자식 노드는 무조건 멘티로 정해져야 한다. 앞서 말한 바처럼 같은 부모 노드를 갖는 노드들은 시너지 최댓값에 있어서 서로 독립적이므로 한 자식 노드가 멘티로 선택되어도 나머지 자식 노드는 영향을 받지 않는다.
코드에서 which 배열은 구현상의 편의를 위해 각 노드에 관하여 멘토로 선택되는지 여부를 저장하고자 선언했다.
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int N_MAX = 200000;
int n;
int a[N_MAX + 1];
int dp[N_MAX + 1][2];
int which[N_MAX + 1];
vector<int> edges[N_MAX + 1];
int go(int now, int check){
int& ret = dp[now][check];
if (ret != -1){
return ret;
}
ret = 0;
// 현재 노드가 어떠한 역할도 없거나 멘티로 선택된 경우
if (check == 0){
int sum = 0;
// 각 자식들은 서로 독립적이므로 가능한 모든 자식들에게서 오는 값들의 합이 클수록 좋다.
for (auto child: edges[now]){
sum += max(go(child, 0), go(child, 1));
}
ret = max(ret, sum);
}
// 현재 노드가 멘토로 선택되는 경우
else if (check == 1){
// 하나의 자식은 무조건 멘티여야 한다.
// 하나의 자식만 멘티로 설정이 되면 나머지 자식들은 서로 독립적이므로 오는 값들의 합을 크게 만들어줘야 한다.
int sum = 0;
for (auto child: edges[now]){
int temp1 = go(child, 0);
int temp2 = go(child, 1);
if (temp1 > temp2){
sum += temp1;
which[child] = 0;
}
else {
sum += temp2;
which[child] = 1;
}
}
for (auto child: edges[now]){
int temp = sum - go(child, which[child]);
temp += (go(child, 0) + a[now] * a[child]);
ret = max(ret, temp);
}
}
return ret;
}
int main(){
memset(dp, -1, sizeof(dp));
scanf("%d", &n);
for (int i = 2; i <= n; i++){
int boss; scanf("%d", &boss);
edges[boss].push_back(i);
}
for (int i = 1; i <= n; i++){
scanf("%d", &a[i]);
}
int ans = 0;
for (int i = 0; i < 2; i++){
ans = max(ans, go(1, i));
}
printf("%d\n", ans);
return 0;
}