转化一下题面,要求划分一个颜色数最少的连通块,使其的颜色与未在连通块中的颜色无交。

考虑暴力怎么做,钦定一个点,将它的颜色加入队列,那么连通块必须包含这个颜色的所有点到这个点的路径,其中可能找到另外的颜色,加入队列,继续操作。

找路径的操作暴力跳父亲即可,如果一个点被访问过了就退出(因为上面的部分一定会被别的点访问)。

那么每一个点都只会被访问一次,做一次的时间复杂度是 $O(n)$。

做一次的时间复杂度似乎无法优化,我们考虑如何少做几次或者减少做的规模。

可以使用点分治,每次以分治中心为根,在其点分树的子树中做一次。

因为如果取外面的点一定不优,因为一定跨过了上一层分治中心。

注意需要判断做的范围是否包含需要颜色的所有点。

时间复杂度 $O(n \log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5;
int n,K;
vector<int> E[maxn+5];
int col[maxn+5];
vector<int> C[maxn+5];
int tot;
vector<int> tsz[maxn+5];
bool vis[maxn+5];
int sz[maxn+5],mx[maxn+5],wg;
vector<int> V;
int cnt[maxn+5];
bool flg[maxn+5];
int fa[maxn+5];
queue<int> qu;
int ans=INT_MAX;
void Find(int u,int F){
sz[u]=1,mx[u]=0;
for(int v:E[u]){
if(v==F||vis[v]) continue;
Find(v,u);
sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],tot-sz[u]);
if(wg==0||mx[u]<mx[wg]) wg=u;
}
void Dfs(int u,int F){
fa[u]=F,sz[u]=1;
V.emplace_back(u);
for(int v:E[u]){
if(v==F||vis[v]) continue;
Dfs(v,u);
sz[u]+=sz[v];
}
}
void Solve(int u){
wg=0,Find(u,0);
V.clear(),Dfs(wg,0);
for(int v:V) cnt[col[v]]=0,flg[v]=0;
for(int v:V) cnt[col[v]]++;
while(qu.size()) qu.pop();
qu.push(col[wg]);
int res=0;
while(qu.size()){
int t=qu.front();qu.pop();
if(!cnt[t]) continue;
if(cnt[t]<(int)C[t].size()){res=INT_MAX;break;}
res++;
for(int v:C[t]){
flg[v]=1,cnt[t]--;
for(int w=fa[v];!flg[w];w=fa[w])
flg[w]=1,qu.push(col[w]);
}
}
// cerr<<wg<<endl;
ans=min(ans,res);
int twg=wg;
vis[twg]=1;
for(int i=0;i<(int)E[twg].size();i++)
tsz[twg].emplace_back(sz[E[twg][i]]);
for(int i=0;i<(int)E[twg].size();i++){
if(vis[E[twg][i]]) continue;
tot=tsz[twg][i];
Solve(E[twg][i]);
}
}
signed main(){
// freopen("01-07.in","r",stdin);
int u,v;
scanf("%d%d",&n,&K);
for(int i=1;i<n;i++){
scanf("%d%d",&u,&v);
E[u].emplace_back(v);
E[v].emplace_back(u);
}
for(int i=1;i<=n;i++){
scanf("%d",&col[i]);
C[col[i]].emplace_back(i);
}
tot=n;
Solve(1);
printf("%d",ans-1);
return 0;
}