1 条题解
-
0
第三部分:题目分析与标准代码
1. 状态定义
这是一个分组背包模型。
- :在以 为根的子树中,恰好选择 个节点(其中必须包含 自己),所能获得的最大权值。
2. 状态转移
对于节点 的每一个子节点 : 我们将 子树看作一组物品。这组物品可以选择拿 个、 个 直到 个。 我们要将子树 的选法合并到 的状态中。
$$dp[u][j] = \max_{0 \le k < j} (dp[u][j-k] + dp[v][k]) $$- :当前 子树(包含已经合并过的其他子树)总共选多少人。倒序枚举。
- :分给当前子节点 的名额。
3. 复杂度优化(树形背包的 优化)
如果暴力枚举 和 到 ,复杂度是 。 优化策略:
- 的上限不需要到 ,只需要到 (当前已合并的子树大小)。
- 的上限不需要到 ,只需要到 。 这样优化后,总复杂度可以证明为 。
4. 标准代码 (C++14)
/** * 题目:孟尝君的门客 * 难度:GESP 6级 / 提高+ * 算法:树形背包 DP */ #include <iostream> #include <vector> #include <algorithm> using namespace std; // 开启 IO 优化 void fast_io() { ios::sync_with_stdio(false); cin.tie(NULL); } const int MAXN = 305; const int INF = 0x3f3f3f3f; // 足够大的数,代表不可达 vector<int> adj[MAXN]; int val[MAXN]; int sz[MAXN]; // 子树大小 // dp[u][j]: 以u为根的子树选j个(含u)的最大权值 int dp[MAXN][MAXN]; int N, K; void dfs(int u) { // 1. 初始化 sz[u] = 1; dp[u][1] = val[u]; // 选1个就是选自己 // 2. 遍历子节点(分组背包) for (int v : adj[u]) { dfs(v); // 3. 状态转移(背包合并) // 倒序枚举当前背包容量 j // 优化:j 的上限是 min(K, sz[u] + sz[v]) for (int j = min(K, sz[u] + sz[v]); j >= 2; --j) { // 枚举分给子树 v 的名额 k // k 的范围:1 到 sz[v] (v子树最多这么大) // 且 k < j (因为 u 自己至少占 1 个,留给 v 的最多 j-1) // 且 j-k <= sz[u] (留给 u 原本部分的不能超过 u 原本的大小) // 这里的 sz[u] 是指合并 v 之前的大小 // 为了代码简洁,通常只写主要边界,依赖 dp 初始化的 -INF 来过滤非法状态 for (int k = 1; k <= sz[v] && k < j; ++k) { if (dp[u][j-k] != -INF && dp[v][k] != -INF) { dp[u][j] = max(dp[u][j], dp[u][j-k] + dp[v][k]); } } } // 合并完一个子节点后,更新 u 的大小 sz[u] += sz[v]; } } int main() { fast_io(); if (!(cin >> N >> K)) return 0; int root = 0; // 初始化 DP 数组为 -INF for(int i=0; i<=N; i++) { for(int j=0; j<=K; j++) { dp[i][j] = -INF; } } for (int i = 1; i <= N; ++i) { int p; cin >> p >> val[i]; if (p == 0) root = i; else adj[p].push_back(i); } dfs(root); // 题目保证一定有解吗? // 如果 K > N,或者树结构导致无法选 K 个,dp[root][K] 仍为 -INF // 但题目说 K <= N,且是一棵树,所以一定能选出 K 个 cout << dp[root][K] << endl; return 0; }
第四部分:数据生成器
生成
1.in~10.in及其对应标准答案。包含链状、随机、负权值等情况。/** * GESP 6级 [孟尝君的门客] - 数据生成器 */ #include <iostream> #include <fstream> #include <vector> #include <algorithm> #include <cstdlib> #include <ctime> using namespace std; // ------------------------------------------ // 标准解法函数 (生成 .out) // ------------------------------------------ const int MAXN_S = 305; const int INF_S = 1e9; vector<int> adj_s[MAXN_S]; int val_s[MAXN_S]; int sz_s[MAXN_S]; int dp_s[MAXN_S][MAXN_S]; void dfs_solve(int u, int K) { sz_s[u] = 1; dp_s[u][1] = val_s[u]; for (int v : adj_s[u]) { dfs_solve(v, K); // 这里的 sz_s[u] 是合并 v 之前的大小 for (int j = min(K, sz_s[u] + sz_s[v]); j >= 2; --j) { for (int k = 1; k <= sz_s[v] && k < j; ++k) { if (dp_s[u][j-k] > -INF_S && dp_s[v][k] > -INF_S) { dp_s[u][j] = max(dp_s[u][j], dp_s[u][j-k] + dp_s[v][k]); } } } sz_s[u] += sz_s[v]; } } int solve(int N, int K, int root, const vector<pair<int, int>>& nodes, const vector<pair<int, int>>& edges) { for(int i=1; i<=N; i++) { adj_s[i].clear(); for(int j=0; j<=K; j++) dp_s[i][j] = -INF_S; } for(int i=1; i<=N; i++) val_s[i] = nodes[i-1].second; for(auto& e : edges) adj_s[e.first].push_back(e.second); dfs_solve(root, K); return dp_s[root][K]; } // 辅助函数 int randRange(int min, int max) { return rand() % (max - min + 1) + min; } int main() { srand(time(0)); cout << "Start generating data..." << endl; for (int i = 1; i <= 10; ++i) { string in_name = to_string(i) + ".in"; string out_name = to_string(i) + ".out"; ofstream fin(in_name); ofstream fout(out_name); int N, K; // 构造测试点 if (i == 1) { // 样例1 N = 5; K = 3; } else if (i == 2) { // 样例2 N = 6; K = 4; } else if (i == 3) { // 链状 N = 20; K = 10; } else if (i == 4) { // 菊花图 N = 20; K = 5; } else if (i <= 7) { // 小规模随机 N = randRange(30, 50); K = randRange(1, N); } else { // 大规模随机 N = randRange(200, 300); K = randRange(1, N); } vector<pair<int, int>> nodes(N); vector<pair<int, int>> edges; int root = 1; // 生成树结构:i 的父亲在 1~i-1 中选,保证 1 是根 vector<int> p(N + 1, 0); for(int k=2; k<=N; k++) { if(i == 3) p[k] = k - 1; // 链状 else if(i == 4) p[k] = 1; // 菊花 else p[k] = randRange(1, k - 1); // 随机 edges.push_back({p[k], k}); } // 生成权值 for(int k=1; k<=N; k++) { if (i == 1 && k <= 5) { // 样例1 数据 int v[] = {0, 10, 5, 6, 3, 4}; nodes[k-1] = {p[k], v[k]}; } else if (i == 2 && k <= 6) { // 样例2 数据 int v[] = {0, 100, -10, 20, 50, 10, 1000}; nodes[k-1] = {p[k], v[k]}; } else { // 随机权值,包含负数 nodes[k-1] = {p[k], randRange(-100, 100)}; } } // 写入输入 fin << N << " " << K << endl; for (int k=0; k<N; k++) { fin << nodes[k].first << " " << nodes[k].second << endl; } // 写入输出 fout << solve(N, K, root, nodes, edges) << endl; fin.close(); fout.close(); cout << "Generated Case " << i << endl; } cout << "Done!" << endl; return 0; }
- 1
信息
- ID
- 19299
- 时间
- 1000ms
- 内存
- 32MiB
- 难度
- 10
- 标签
- (无)
- 递交数
- 1
- 已通过
- 1
- 上传者