N개의 도시, 설치 가능한 케이블 M개, 그리고 K개의 발전소가 있다. 도시가 발전소에 연결되면 발전소에 의해 생산된 전기가 연결된 도시에 공급된다. N개의 도시 모두 전기를 공급받을 수 있도록 설치해야 하는 케이블의 최소 비용을 구해야 한다. 단, 한 개의 도시는 반드시 한 개의 발전소에 의해서만 전기를 공급받아야 한다.
한 개의 도시에 두 개 이상의 발전소가 연결되어 있지 않아야 한다는 조건이 없었으면 해당문제는 MST(Minumum Spanning Tree)를 구하는 과정으로 쉽게 해결할 수 있다. 그러나 해당 조건 때문에 단순히 Kruskal 또는 Prim 알고리즘을 구현하는 것만으로 해결이 불가하다. 조건에 위배되지 않는 방법은 크게 두 가지가 있을 것이다.
1. 전체 그래프에 대한 MST를 구하고 MST에서 필요하지 않은 간선만 제외하기
2. MST를 구하는 과정에서 조건을 고려하여 일부 과정 변경 또는 추가
1번 방법은 MST를 구성하는 각 간선의 유무에 따라 모든 도시마다 하나의 발전소에 의해서만 전기를 공급받는지 따져야 하므로 TLE인 것은 물론, 이를 구현하는 것 또한 쉽지 않다. 따라서 2번 방법을 선택했는데, MST를 구하는 과정에서 보고 있는 간선을 최소 비용 케이블에 포함시켜야 하는지를 결정할 수 있는 Kruskal 알고리즘의 Union-Find을 사용했다.
어차피 모든 도시들은 어떠한 발전소에 의해 전기를 공급받는지를 따질 필요 없이 임의의 발전소 하나에 의해서만 전기를 공급받으면 된다. 따라서 Union-Find 과정에서 발전소를 각각 다른 정점(원소)으로 취급하지 않고 같은 집합의 정점(원소)으로 고려할 수 있도록 처음에 발전소 정점(원소)의 parent를 모두 -1로 0으로 통일한다. 그러면 Kruskal 알고리즘 실행 과정에서 하나의 발전소에 의해 연결되는 도시와 아직 연결되지 않은 도시로 나뉠 것이다. 즉, 하나의 발전소에 의해 연결되는 도시는 모두 parent가 -1 0인 같은 집합에 속할 것이다.
구체적으로 과정을 서술하면 다음과 같다.
1. 모든 발전소의 parent는 -1 0으로, 나머지 도시 정점의 parent는 각 도시 정점 번호로 초기화한다. (기존 Union-Find에 추가한 과정)
2. 그래프를 구성하는 모든 간선을 비용 크기 오름차순으로 정렬한다.
3. 정렬한 간선 차례대로 간선이 연결하는 두 정점이 속한 집합이 서로 다르면 merge 하고, 그렇지 않으면 merge 한다. merge 가능하면 해당 간선의 비용을 지금까지의 케이블 설치 최소 비용에 더한다. 단, 두 정점이 merge 가능하고 두 정점 중 하나라도 정점의 parent가 -1이면 다른 한 정점의 parent도 -1 0으로 만든다.
parent 배열을 -1이 아니라 0으로 초기화해야 합니다. -1로 초기화하면 find_parent 함수를 재귀적으로 실행하여 자신의 parent를 찾아갈 때 인덱스로 -1이 들어와서 에러가 발생할 수 있습니다.
#include <cstdio>
#include <vector>
#include <tuple>
#include <algorithm>
using namespace std;
const int N_MAX = 1000;
int n, m, k;
int parent[N_MAX + 1];
vector<tuple<int, int, int>> edges;
int find_parent(int x){
if (parent[x] == x) return x;
return parent[x] = find_parent(parent[x]);
}
bool mergable(int x, int y){
x = find_parent(x);
y = find_parent(y);
if (x != y){
if (x == 0) {
parent[y] = x;
}
else {
parent[x] = y;
}
return true;
}
else {
return false;
}
}
int main(){
scanf("%d %d %d", &n, &m, &k);
for (int i = 1; i <= n; i++){
parent[i] = i;
}
for (int i = 0; i < k; i++){
int x; scanf("%d", &x);
parent[x] = 0;
}
for (int i = 0; i < m; i++){
int u, v, w; scanf("%d %d %d", &u, &v, &w);
edges.push_back({w, u, v});
}
sort(edges.begin(), edges.end());
int ans = 0;
for (int i = 0; i < m; i++){
int x, y, w; tie(w, x, y) = edges[i];
if (mergable(x, y)){
ans += w;
}
}
printf("%d\n", ans);
return 0;
}