>
对于链(链可以跳过不看):
设 表示当前到了第 个人,不看/看电视的最小代价。
初始状态 ,目标状态 。
对于树:
设 表示以 为根的子树,根节点是否在一个联通块里的最小代价。
初始状态每个节点的 ,当 时,,否则 。
首先注意到答案是单调上升的,相邻两个答案之间的差值是不上升的。
当 时,差值递减的很快,当 时,有很长一段答案有相同的差值,并且差值之间相差 ,考虑根号分治。
时间复杂度 , 为 时 的值,需要加 dfs 序优化。
#include <bits/stdc++.h>
using namespace std;
inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) f -= (ch == '-') * 2;
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
return x * f;
}
const int N = 200010, M = N << 1;
int n;
char s[N];
namespace Chain {
int f[N][2], d[N];
int dp(int k) {
memset(f, 0x3f, sizeof f);
f[0][0] = 0;
for (int i = 1; i <= n; i ++ ) {
if (s[i] == '0') f[i][0] = min(f[i - 1][0], f[i - 1][1]);
f[i][1] = min(f[i - 1][0] + k + 1, f[i - 1][1] + 1);
}
return min(f[n][0], f[n][1]);
}
void solve() {
int b = sqrt(n), lst = 0;
for (int k = 1; k <= b; k ++ ) {
lst = dp(k);
printf("%d\n", lst);
}
int maxd = dp(b + 1) - dp(b);
d[maxd + 1] = b;
for (int i = maxd; i; i -- ) {
int l = b + 1, r = n, ans = b + 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (dp(mid) - dp(mid - 1) >= i) l = mid + 1, ans = mid;
else r = mid - 1;
}
d[i] = ans;
}
for (int i = maxd; i; i -- ) {
int cnt = d[i] - d[i + 1];
while (cnt -- ) lst = lst + i, printf("%d\n", lst);
}
}
}
namespace STD {
const int INF = 0x3f3f3f3f;
const int N = 200010;
int h[N], e[N << 1], ne[N << 1], idx;
int f[N][2], id[N], d[N], fa[N], ts;
void add(int a, int b) {
e[ ++ idx] = b, ne[idx] = h[a], h[a] = idx;
}
void dfsx(int u, int f) {
id[ ++ ts] = u;
for (int i = h[u]; i; i = ne[i]) {
int j = e[i];
if (j == f) continue;
fa[j] = u;
dfsx(j, u);
}
}
int dp(int k) {
for (int i = 1; i <= n; i ++ ) {
f[i][0] = INF, f[i][1] = 1;
if (s[i] == '0') f[i][0] = 0;
}
for (int i = n; i > 1; i -- ) {
int j = id[i], u = fa[j];
if (s[u] == '0') f[u][0] += min(f[j][0], f[j][1] + k);
f[u][1] += min(f[j][0], f[j][1]);
}
return min(f[1][1] + k, f[1][0]);
}
void solve() {
for (int i = 1; i < n; i ++ ) {
int a = read(), b = read();
add(a, b), add(b, a);
}
dfsx(1, -1);
int b = sqrt(n), lst = 0;
for (int k = 1; k <= b; k ++ ) {
lst = dp(k);
printf("%d\n", lst);
}
int maxd = dp(b + 1) - dp(b);
d[maxd + 1] = b;
for (int i = maxd; i; i -- ) {
int l = b + 1, r = n, ans = b + 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (dp(mid) - dp(mid - 1) >= i) l = mid + 1, ans = mid;
else r = mid - 1;
}
d[i] = ans;
}
for (int i = maxd; i; i -- ) {
int cnt = d[i] - d[i + 1];
while (cnt -- ) lst = lst + i, printf("%d\n", lst);
}
}
}
int main() {
n = read();
scanf("%s", s + 1);
STD::solve();
return 0;
}