문제 내용이 길고 복잡해서 처음에는 문제 풀기가 까다로웠다. 그런데 간단히 정리하면 다음과 같다.
단방향 그래프가 주어지고 각 정점 방문 시 얻을 수 있는 이익과 각 간선을 탔을 때의 비용이 주어졌을 때, 해당 그래프를 탐색해서 얻을 수 있는 순이익의 최댓값을 구한다.
앞의 작업 진척도에 관한 내용은 사실 크게 의미있는 내용은 아니다.
문제에서 주어지는 그래프가 트리라는 정보는 언급되지 않았다. 그리고 입력 부분에서 '동굴 a에서 b로 작업 장비를 들여놓는 데 c만큼의 비용이 든다'라고 했으므로 단방향 그래프로 나타낼 수 있다. 그래서 처음에는 위상정렬을 사용해야 하는 문제인 줄 알았다. 그런데 그래프에서 cycle이 존재하지 않는다는 전제가 없어서 함부로 위상정렬을 사용하여 풀 수가 없었다. 그러면 '강한 연결 요소(Strongly Connected Component)를 찾는 알고리즘을 사용해야 하나?'라는 생각도 들었지만, 너무 문제 풀이가 복잡해질 것 같아서 다른 방법으로 풀고자 했다.
사실 생각해보면 주어지는 그래프에 굳이 cycle이 있고 없고가 그렇게 중요하지 않다고 느꼈다. 문제에서 탐사를 끝낸 뒤 이전 동굴로 장비를 다시 돌려놓지 않는다고 했으므로 그래프의 leaf 노드를 방문하면 더 이상 탐사를 진행하지 않는다. 다시 말해서, 한 번 방문한 정점은 다시 방문할 수 없고 탐사를 시작하는 루트 노드부터 차례대로 직접 연결된 정점으로 갈 수 있을 만큼 탐사를 진행한다는 것이다. 따라서 간단히 DFS(Depth First Search)를 통해서 해결이 가능하다. 대신 정점을 방문했을 때 앞으로 더 갈 수 있는 여러 경로 중에서 이익의 최댓값을 찾아야 하므로 다이나믹 프로그래밍(Dynamic Programming)으로 각 정점에서 경로 끝까지 진행했을 때 얻을 수 있는 이익의 최댓값을 메모이제이션 하는 과정이 필요하다.
최대 이익을 얻을 수 있는 경로를 찾아서 출력하는 것이 관건인데, 경로의 모든 정점을 일일이 직접 저장해 줄 필요는 없다. DFS 탐색하면서 최대 이익 경로를 찾았을 때 현재 방문 정점에서 해당 경로를 탐색할 때 방문하게 되는 다음 정점 번호만 저장해주면 된다. 나중에 경로 상의 모든 정점 번호를 출력할 때도 루트 노드부터 path 값을 연쇄적으로 따라가서 유효하지 않은 정점 번호가 나올 때까지 정점 번호를 출력해주면 된다. 단, 이렇게 배열을 선언하여 어떤 정점에서 왔는지를 저장하는 방법은 이 문제처럼 그래프 탐색 경로가 cycle을 타지 않을 때 유효하다. cycle을 탈 수 있는 경우의 경로 탐색은 백트래킹(Back Tracking)으로 해야 하며, 이와 관련한 문제는 아래를 참고해 보면 좋다.
dp[i]: i번 정점에서 앞으로 더 진행할 수 있는 경로로 탐사를 진행했을 때 얻을 수 있는 최대 이익
value[i]: i번 정점에 도달했을 때 얻을 수 있는 이익(가치)
path[i]: 최대 이익을 얻을 수 있는 경로 상에서 i번 정점 다음으로 방문하게 되는 정점 번호
각 정점마다 도착했을 때 앞으로 더 진행하여 얻을 수 있는 최대 이익을 구해야 하는데, 이는 해당 정점의 자식 정점 중 가장 큰 dp 값을 찾아주면 된다. 단, 자식 정점으로 갈 때의 간선의 cost를 고려해야 하고 현재 정점에서 얻을 수 있는 이익을 더해야 한다. 이를 정리하면 다음과 같다.
#include <cstdio>
#include <vector>
#include <string.h>
#include <algorithm>
using namespace std;
using pp = pair<int, int>;
const int N_MAX = 2 * (int)1e4;
int t, n, e;
int dp[N_MAX + 1];
int path[N_MAX + 1];
int value[N_MAX + 1];
vector<pp> edges[N_MAX + 1];
int go(int now){
if (dp[now] != -1){
return dp[now];
}
int &ret = dp[now];
ret = 0;
for (auto e: edges[now]){
int next = e.first;
int cost = e.second;
int sum = go(next);
if (ret < sum - cost){
ret = sum - cost;
path[now] = next;
}
}
ret += value[now];
return ret;
}
int main(){
scanf("%d", &t);
while(t--){
scanf("%d %d", &n, &e);
memset(dp, -1, sizeof(dp));
for (int i = 1; i <= n; i++){
cnt[i] = 1;
path[i] = 0;
edges[i].clear();
}
for (int i = 1; i <= n; i++){
scanf("%d", &value[i]);
}
for (int i = 0; i < e; i++){
int a, b, c;
scanf("%d %d %d", &a, &b, &c);
edges[a].push_back({b, c});
}
int ret = go(1);
vector<int> ans;
for (int i = 1; i != 0; i = path[i]){
ans.push_back(i);
}
printf("%d %d\n", ret, (int)ans.size());
for (int v : ans){
printf("%d ", v);
}
printf("\n");
}
return 0;
}