Link
https://www.acmicpc.net/problem/2213
Detail
로직 자체는 매우 간단한 점화식으로 해결이 가능하다. 인접하지 않은 점들만 모아야 한다고 하는 문장에서 레드블랙트리의 이미지가 떠올라. 이를 기반으로 문제를 풀어나가면 된다고 생각했다. 다만 각 정점의 가중치가 1이라면 레드블랙트리처럼 정점의 depth가 짝수인 정점의 수와 홀수인 점의 수 중에서 큰 값이 정답이겠으나, 이 문제는 정점에 가중치가 존재한다.
따라서 이 가중치를 어떻게 더하냐에 따라 갈리기때문에 dp 점화식을 사용해야 한다. 점화식은 다음과 같다.
- 는 i가 루트 노드인 트리(또는 서브트리)에서 구할 수 있는 독립집합의 최댓값이다.
- 루트 노드를 독립집합에 포함시키지 않을 경우 (exclude_case)
- 루트 노드를 독립집합에 포함시킬 경우 (include_case)
1. Exclude_case
i의 자식 노드 j들에 대해 다음과 같은 식이 성립한다.
2. Include_case
i의 두번째 자식 (코드에서는 grand_child라고 표현) 노드 k에 다음과 같은 식이 성립한다.
단순히 합만 출력한다면 훨씬 간단한 문제가 되었겠으나, 포함한 노드들 역시 출력해야 했기 때문에 tabulation 방식을 택하여 문제를 풀었다.
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
vector<int> values;
vector<vector<int>> tree;
// {depth, node}
vector<pair<int, int>> depth_info;
// dp[i] = i번째 노드를 루트로 하는 서브트리의 계산 최댓값, 그리고 그 값을 구성하는 노드들의 번호
vector<pair<int, vector<int>>> dp;
int main() {
int v_cnt; cin >> v_cnt;
values.resize(v_cnt);
tree.resize(v_cnt);
depth_info.resize(v_cnt);
vector<bool> visited(v_cnt, false);
for (int i = 0; i < v_cnt; i++) {
cin >> values[i];
}
visited[0] = true;
for (int i = 0; i < v_cnt - 1; i++) {
int u, v; cin >> u >> v;
u--; v--;
if (visited[u]) {
tree[u].push_back(v);
depth_info[v] = {depth_info[u].first + 1, v};
visited[v] = true;
} else {
tree[v].push_back(u);
depth_info[u] = {depth_info[v].first + 1, u};
visited[u] = true;
}
}
// sort by depth
sort(depth_info.begin(), depth_info.end(), greater<>());
//debug
// for (auto& [depth, node] : depth_info) {
// cout << depth << " " << node + 1 << endl;
// }
// build dp
dp.resize(v_cnt);
for (int i = 0; i < v_cnt; i++) {
int node = depth_info[i].second;
int depth = depth_info[i].first;
// leaf node
if (tree[node].empty()) {
dp[node] = {values[node], {node}};
continue;
}
int include_node = values[node];
vector<int> include_nodes = {node};
for (int child : tree[node]) {
for (int grand_child : tree[child]) {
include_node += dp[grand_child].first;
include_nodes.insert(include_nodes.end(), dp[grand_child].second.begin(), dp[grand_child].second.end());
}
}
int exclude_node = 0;
vector<int> exclude_nodes;
for (int child : tree[node]) {
exclude_node += dp[child].first;
exclude_nodes.insert(exclude_nodes.end(), dp[child].second.begin(), dp[child].second.end());
}
if (include_node > exclude_node) {
dp[node] = {include_node, include_nodes};
} else {
dp[node] = {exclude_node, exclude_nodes};
}
}
// print answer
cout << dp[0].first << endl;
sort(dp[0].second.begin(), dp[0].second.end());
for (int node : dp[0].second) {
cout << node + 1 << " ";
}
cout << endl;
return 0;
}