考虑DP,首先是期望逆推,这样不用折算概率。

设 $f(i,a,b,c)$ 表示砍了 $i$ 刀,$a$ 个一血奴隶主,$b$ 个二血奴隶主,$c$ 个三血奴隶主的期望伤害。

这里直接给出逆推的状态转移方程:

令:

$$
add = a + b + c < K
\\
cnt = a + b + c + 1
$$

有:

$$
\begin{aligned}
f(i,a,b,c) &= \frac{a}{cnt} f(i+1,a-1,b,c) + \frac{b}{cnt} f(i+1,a+1,b-1,c+add)
\\ &+ \frac{c}{cnt} f(i+1,a,b+1,c-1+add) + \frac{1}{cnt} (1 + f(i+1,a,b,c))
\end{aligned}
$$

由于 $n \leq 10^{18}$ 显然考虑矩阵优化。

我们先压缩状态,将 $(a,b,c)$ 标号,最多有 $165$ 个。

注意到有一个 $+1$ 的系数,我们把它用状态 $0$ 表示,跟着一起转移。

如何构造转移矩阵

代入矩阵乘法定义即可,具体来说,$k x \to y$,则转移矩阵 $arr[x][y] = k$。

那么 $n$ 次迭代后的结果即为 $a \times b^n$ 其中 $a$ 表示初始状态,$b$ 表示转移矩阵。

值得注意的是若 $tot$ 表示状态数,$a$ 为 $1 \times tot$ 的矩阵,$b$ 为 $tot \times tot$ 的矩阵,结果为 $1 \times tot$ 的矩阵。

一些代码实现细节

由于多组数据,我们可以先预处理 $b^{2^k}$,这部分时间复杂度是 $O(\log_2 n \cdot tot^3)$ 的。

后面快速幂的部分每次我们都调用预处理好的 $b^{2^k}$,这里的矩阵乘法是 $O(tot^2)$ 的。

你需要使用常数更小的矩阵乘法,这个差异是非常大的。

是的没错,每一种 $m$ 的转移都不一样……

对于这道题目,需要格外注意的是 $+1$ 的系数需要一起转移,即 $arr[0][0] = 1$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int mod=998244353;
const int maxn=50;
const int maxm=8;
const int maxs=170;
const int maxg=60;
struct Mat{int a[maxs+5][maxs+5],N,M;};
ll n;
int m,K;
int id[maxm+5][maxm+5][maxm+5],tot;
Mat bs[maxg+5];
inline int Fpow(int x,int y){
int res=1;
for(;y;x=1ll*x*x%mod,y>>=1)
if(y&1) res=1ll*res*x%mod;
return res;
}
Mat operator*(const Mat &x,const Mat &y){
Mat z;
z.N=x.N,z.M=y.M;
for(int i=0;i<=z.N;i++)
for(int j=0;j<=z.M;j++)
z.a[i][j]=0;
for(int i=0;i<=x.N;i++)
for(int k=0;k<=x.M;k++)
for(int j=0;j<=y.M;j++)
z.a[i][j]=(z.a[i][j]+1ll*x.a[i][k]*y.a[k][j])%mod;
return z;
}
inline void Solve(){
scanf("%lld",&n);
Mat res;
res.N=0,res.M=tot;
res.a[0][0]=1;
for(int i=1;i<=tot;i++) res.a[0][i]=0;
for(int i=0;n;n>>=1,i++)
if(n&1) res=res*bs[i];
if(m==1) printf("%d\n",res.a[0][id[1][0][0]]);
else if(m==2) printf("%d\n",res.a[0][id[0][1][0]]);
else printf("%d\n",res.a[0][id[0][0][1]]);
}
signed main(){
int T;
scanf("%d%d%d",&T,&m,&K);
if(m==1){
for(int i=0;i<=K;i++)
id[i][0][0]=++tot;
for(int i=0;i<=K;i++){
int cnt=Fpow(i+1,mod-2);
if(i) bs[0].a[id[i-1][0][0]][id[i][0][0]]=1ll*i*cnt%mod;
bs[0].a[id[i][0][0]][id[i][0][0]]=cnt;
bs[0].a[0][id[i][0][0]]=cnt;
}
}
else if(m==2){
for(int i=0;i<=K;i++)
for(int j=0;i+j<=K;j++)
id[i][j][0]=++tot;
for(int i=0;i<=K;i++){
for(int j=0;i+j<=K;j++){
int cnt=Fpow(i+j+1,mod-2);
int tmp=i+j<K;
if(i) bs[0].a[id[i-1][j][0]][id[i][j][0]]=1ll*i*cnt%mod;
if(j) bs[0].a[id[i+1][j-1+tmp][0]][id[i][j][0]]=1ll*j*cnt%mod;
bs[0].a[id[i][j][0]][id[i][j][0]]=cnt;
bs[0].a[0][id[i][j][0]]=cnt;
}
}
}
else{
for(int i=0;i<=K;i++)
for(int j=0;i+j<=K;j++)
for(int k=0;i+j+k<=K;k++)
id[i][j][k]=++tot;
for(int i=0;i<=K;i++){
for(int j=0;i+j<=K;j++){
for(int k=0;i+j+k<=K;k++){
int cnt=Fpow(i+j+k+1,mod-2);
int tmp=i+j+k<K;
if(i) bs[0].a[id[i-1][j][k]][id[i][j][k]]=1ll*i*cnt%mod;
if(j) bs[0].a[id[i+1][j-1][k+tmp]][id[i][j][k]]=1ll*j*cnt%mod;
if(k) bs[0].a[id[i][j+1][k-1+tmp]][id[i][j][k]]=1ll*k*cnt%mod;
bs[0].a[id[i][j][k]][id[i][j][k]]=cnt;
bs[0].a[0][id[i][j][k]]=cnt;
}
}
}
}
bs[0].a[0][0]=1;
bs[0].N=bs[0].M=tot;
for(int i=1;i<=maxg;i++) bs[i]=bs[i-1]*bs[i-1];
while(T--) Solve();
return 0;
}