Link

https://www.acmicpc.net/problem/2213

Detail

로직 자체는 매우 간단한 점화식으로 해결이 가능하다. 인접하지 않은 점들만 모아야 한다고 하는 문장에서 레드블랙트리의 이미지가 떠올라. 이를 기반으로 문제를 풀어나가면 된다고 생각했다. 다만 각 정점의 가중치가 1이라면 레드블랙트리처럼 정점의 depth가 짝수인 정점의 수와 홀수인 점의 수 중에서 큰 값이 정답이겠으나, 이 문제는 정점에 가중치가 존재한다.

따라서 이 가중치를 어떻게 더하냐에 따라 갈리기때문에 dp 점화식을 사용해야 한다. 점화식은 다음과 같다.

  1. 는 i가 루트 노드인 트리(또는 서브트리)에서 구할 수 있는 독립집합의 최댓값이다.
    1. 루트 노드를 독립집합에 포함시키지 않을 경우 (exclude_case)
    2. 루트 노드를 독립집합에 포함시킬 경우 (include_case)

1. Exclude_case

i의 자식 노드 j들에 대해 다음과 같은 식이 성립한다.

2. Include_case

i의 두번째 자식 (코드에서는 grand_child라고 표현) 노드 k에 다음과 같은 식이 성립한다.

단순히 합만 출력한다면 훨씬 간단한 문제가 되었겠으나, 포함한 노드들 역시 출력해야 했기 때문에 tabulation 방식을 택하여 문제를 풀었다.

Code

#include <bits/stdc++.h>
 
using namespace std;
typedef long long ll;
 
vector<int> values;
vector<vector<int>> tree;
 
// {depth, node}
vector<pair<int, int>> depth_info;
 
// dp[i] = i번째 노드를 루트로 하는 서브트리의 계산 최댓값, 그리고 그 값을 구성하는 노드들의 번호
vector<pair<int, vector<int>>> dp;
 
int main() {
    int v_cnt; cin >> v_cnt;
 
    values.resize(v_cnt);
    tree.resize(v_cnt);
    depth_info.resize(v_cnt);
 
    vector<bool> visited(v_cnt, false);
    for (int i = 0; i < v_cnt; i++) {
        cin >> values[i];
    }
    visited[0] = true;
 
    for (int i = 0; i < v_cnt - 1; i++) {
        int u, v; cin >> u >> v;
        u--; v--;
 
        if (visited[u]) {
            tree[u].push_back(v);
            depth_info[v] = {depth_info[u].first + 1, v};
            visited[v] = true;
        } else {
            tree[v].push_back(u);
            depth_info[u] = {depth_info[v].first + 1, u};
            visited[u] = true;
        }
    }
 
    // sort by depth
    sort(depth_info.begin(), depth_info.end(), greater<>());
 
    //debug
    // for (auto& [depth, node] : depth_info) {
    //     cout << depth << " " << node + 1 << endl;
    // }
 
    // build dp
    dp.resize(v_cnt);
    for (int i = 0; i < v_cnt; i++) {
        int node = depth_info[i].second;
        int depth = depth_info[i].first;
 
        // leaf node
        if (tree[node].empty()) {
            dp[node] = {values[node], {node}};
            continue;
        }
 
        int include_node = values[node];
        vector<int> include_nodes = {node};
        for (int child : tree[node]) {
            for (int grand_child : tree[child]) {
                    include_node += dp[grand_child].first;
                    include_nodes.insert(include_nodes.end(), dp[grand_child].second.begin(), dp[grand_child].second.end());
            }
        }
 
        int exclude_node = 0;
        vector<int> exclude_nodes;
        for (int child : tree[node]) {
            exclude_node += dp[child].first;
            exclude_nodes.insert(exclude_nodes.end(), dp[child].second.begin(), dp[child].second.end());
        }
 
        if (include_node > exclude_node) {
            dp[node] = {include_node, include_nodes};
        } else {
            dp[node] = {exclude_node, exclude_nodes};
        }
    }
 
    // print answer
    cout << dp[0].first << endl;
    sort(dp[0].second.begin(), dp[0].second.end());
    for (int node : dp[0].second) {
        cout << node + 1 << " ";
    }
    cout << endl;
    
    return 0;
}