1 条题解

  • 0
    @ 2025-12-10 11:25:54

    第三部分:题目分析与标准代码

    1. 状态定义

    这是一个分组背包模型。

    • dp[u][j]dp[u][j]:在以 uu 为根的子树中,恰好选择 jj 个节点(其中必须包含 uu 自己),所能获得的最大权值。

    2. 状态转移

    对于节点 uu 的每一个子节点 vv: 我们将 vv 子树看作一组物品。这组物品可以选择拿 11 个、22\dots 直到 size[v]size[v] 个。 我们要将子树 vv 的选法合并到 uu 的状态中。

    $$dp[u][j] = \max_{0 \le k < j} (dp[u][j-k] + dp[v][k]) $$
    • jj:当前 uu 子树(包含已经合并过的其他子树)总共选多少人。倒序枚举。
    • kk:分给当前子节点 vv 的名额。

    3. 复杂度优化(树形背包的 O(N2)O(N^2) 优化)

    如果暴力枚举 jjkkKK,复杂度是 O(NK2)O(N \cdot K^2)。 优化策略:

    1. jj 的上限不需要到 KK,只需要到 size[u]size[u](当前已合并的子树大小)。
    2. kk 的上限不需要到 KK,只需要到 size[v]size[v]。 这样优化后,总复杂度可以证明为 O(N2)O(N^2)

    4. 标准代码 (C++14)

    /**
     * 题目:孟尝君的门客
     * 难度:GESP 6级 / 提高+
     * 算法:树形背包 DP
     */
    
    #include <iostream>
    #include <vector>
    #include <algorithm>
    
    using namespace std;
    
    // 开启 IO 优化
    void fast_io() {
        ios::sync_with_stdio(false);
        cin.tie(NULL);
    }
    
    const int MAXN = 305;
    const int INF = 0x3f3f3f3f; // 足够大的数,代表不可达
    
    vector<int> adj[MAXN];
    int val[MAXN];
    int sz[MAXN]; // 子树大小
    
    // dp[u][j]: 以u为根的子树选j个(含u)的最大权值
    int dp[MAXN][MAXN]; 
    int N, K;
    
    void dfs(int u) {
        // 1. 初始化
        sz[u] = 1;
        dp[u][1] = val[u]; // 选1个就是选自己
    
        // 2. 遍历子节点(分组背包)
        for (int v : adj[u]) {
            dfs(v);
    
            // 3. 状态转移(背包合并)
            // 倒序枚举当前背包容量 j
            // 优化:j 的上限是 min(K, sz[u] + sz[v])
            for (int j = min(K, sz[u] + sz[v]); j >= 2; --j) {
                
                // 枚举分给子树 v 的名额 k
                // k 的范围:1 到 sz[v] (v子树最多这么大)
                // 且 k < j (因为 u 自己至少占 1 个,留给 v 的最多 j-1)
                // 且 j-k <= sz[u] (留给 u 原本部分的不能超过 u 原本的大小)
                // 这里的 sz[u] 是指合并 v 之前的大小
                
                // 为了代码简洁,通常只写主要边界,依赖 dp 初始化的 -INF 来过滤非法状态
                for (int k = 1; k <= sz[v] && k < j; ++k) {
                    if (dp[u][j-k] != -INF && dp[v][k] != -INF) {
                        dp[u][j] = max(dp[u][j], dp[u][j-k] + dp[v][k]);
                    }
                }
            }
            // 合并完一个子节点后,更新 u 的大小
            sz[u] += sz[v];
        }
    }
    
    int main() {
        fast_io();
    
        if (!(cin >> N >> K)) return 0;
    
        int root = 0;
        // 初始化 DP 数组为 -INF
        for(int i=0; i<=N; i++) {
            for(int j=0; j<=K; j++) {
                dp[i][j] = -INF;
            }
        }
    
        for (int i = 1; i <= N; ++i) {
            int p;
            cin >> p >> val[i];
            if (p == 0) root = i;
            else adj[p].push_back(i);
        }
    
        dfs(root);
    
        // 题目保证一定有解吗?
        // 如果 K > N,或者树结构导致无法选 K 个,dp[root][K] 仍为 -INF
        // 但题目说 K <= N,且是一棵树,所以一定能选出 K 个
        cout << dp[root][K] << endl;
    
        return 0;
    }
    

    第四部分:数据生成器

    生成 1.in ~ 10.in 及其对应标准答案。包含链状、随机、负权值等情况。

    /**
     * GESP 6级 [孟尝君的门客] - 数据生成器
     */
    #include <iostream>
    #include <fstream>
    #include <vector>
    #include <algorithm>
    #include <cstdlib>
    #include <ctime>
    
    using namespace std;
    
    // ------------------------------------------
    // 标准解法函数 (生成 .out)
    // ------------------------------------------
    const int MAXN_S = 305;
    const int INF_S = 1e9;
    vector<int> adj_s[MAXN_S];
    int val_s[MAXN_S];
    int sz_s[MAXN_S];
    int dp_s[MAXN_S][MAXN_S];
    
    void dfs_solve(int u, int K) {
        sz_s[u] = 1;
        dp_s[u][1] = val_s[u];
        
        for (int v : adj_s[u]) {
            dfs_solve(v, K);
            // 这里的 sz_s[u] 是合并 v 之前的大小
            for (int j = min(K, sz_s[u] + sz_s[v]); j >= 2; --j) {
                for (int k = 1; k <= sz_s[v] && k < j; ++k) {
                    if (dp_s[u][j-k] > -INF_S && dp_s[v][k] > -INF_S) {
                        dp_s[u][j] = max(dp_s[u][j], dp_s[u][j-k] + dp_s[v][k]);
                    }
                }
            }
            sz_s[u] += sz_s[v];
        }
    }
    
    int solve(int N, int K, int root, const vector<pair<int, int>>& nodes, const vector<pair<int, int>>& edges) {
        for(int i=1; i<=N; i++) {
            adj_s[i].clear();
            for(int j=0; j<=K; j++) dp_s[i][j] = -INF_S;
        }
        
        for(int i=1; i<=N; i++) val_s[i] = nodes[i-1].second;
        for(auto& e : edges) adj_s[e.first].push_back(e.second);
    
        dfs_solve(root, K);
        return dp_s[root][K];
    }
    
    // 辅助函数
    int randRange(int min, int max) {
        return rand() % (max - min + 1) + min;
    }
    
    int main() {
        srand(time(0));
        
        cout << "Start generating data..." << endl;
    
        for (int i = 1; i <= 10; ++i) {
            string in_name = to_string(i) + ".in";
            string out_name = to_string(i) + ".out";
            ofstream fin(in_name);
            ofstream fout(out_name);
    
            int N, K;
            
            // 构造测试点
            if (i == 1) { // 样例1
                N = 5; K = 3;
            } else if (i == 2) { // 样例2
                N = 6; K = 4;
            } else if (i == 3) { // 链状
                N = 20; K = 10;
            } else if (i == 4) { // 菊花图
                N = 20; K = 5;
            } else if (i <= 7) { // 小规模随机
                N = randRange(30, 50);
                K = randRange(1, N);
            } else { // 大规模随机
                N = randRange(200, 300);
                K = randRange(1, N);
            }
    
            vector<pair<int, int>> nodes(N);
            vector<pair<int, int>> edges;
            int root = 1;
    
            // 生成树结构:i 的父亲在 1~i-1 中选,保证 1 是根
            vector<int> p(N + 1, 0);
            for(int k=2; k<=N; k++) {
                if(i == 3) p[k] = k - 1; // 链状
                else if(i == 4) p[k] = 1; // 菊花
                else p[k] = randRange(1, k - 1); // 随机
                
                edges.push_back({p[k], k});
            }
    
            // 生成权值
            for(int k=1; k<=N; k++) {
                if (i == 1 && k <= 5) { // 样例1 数据
                    int v[] = {0, 10, 5, 6, 3, 4}; 
                    nodes[k-1] = {p[k], v[k]};
                } else if (i == 2 && k <= 6) { // 样例2 数据
                    int v[] = {0, 100, -10, 20, 50, 10, 1000};
                    nodes[k-1] = {p[k], v[k]};
                } else {
                    // 随机权值,包含负数
                    nodes[k-1] = {p[k], randRange(-100, 100)};
                }
            }
    
            // 写入输入
            fin << N << " " << K << endl;
            for (int k=0; k<N; k++) {
                fin << nodes[k].first << " " << nodes[k].second << endl;
            }
    
            // 写入输出
            fout << solve(N, K, root, nodes, edges) << endl;
    
            fin.close();
            fout.close();
            cout << "Generated Case " << i << endl;
        }
        
        cout << "Done!" << endl;
        return 0;
    }
    
    • 1

    信息

    ID
    19299
    时间
    1000ms
    内存
    32MiB
    难度
    10
    标签
    (无)
    递交数
    1
    已通过
    1
    上传者