考虑DP,首先是期望逆推,这样不用折算概率。
设 $f(i,a,b,c)$ 表示砍了 $i$ 刀,$a$ 个一血奴隶主,$b$ 个二血奴隶主,$c$ 个三血奴隶主的期望伤害。
这里直接给出逆推的状态转移方程:
令:
$$
add = a + b + c < K
\\
cnt = a + b + c + 1
$$
有:
$$
\begin{aligned}
f(i,a,b,c) &= \frac{a}{cnt} f(i+1,a-1,b,c) + \frac{b}{cnt} f(i+1,a+1,b-1,c+add)
\\ &+ \frac{c}{cnt} f(i+1,a,b+1,c-1+add) + \frac{1}{cnt} (1 + f(i+1,a,b,c))
\end{aligned}
$$
由于 $n \leq 10^{18}$ 显然考虑矩阵优化。
我们先压缩状态,将 $(a,b,c)$ 标号,最多有 $165$ 个。
注意到有一个 $+1$ 的系数,我们把它用状态 $0$ 表示,跟着一起转移。
如何构造转移矩阵
代入矩阵乘法定义即可,具体来说,$k x \to y$,则转移矩阵 $arr[x][y] = k$。
那么 $n$ 次迭代后的结果即为 $a \times b^n$ 其中 $a$ 表示初始状态,$b$ 表示转移矩阵。
值得注意的是若 $tot$ 表示状态数,$a$ 为 $1 \times tot$ 的矩阵,$b$ 为 $tot \times tot$ 的矩阵,结果为 $1 \times tot$ 的矩阵。
一些代码实现细节
由于多组数据,我们可以先预处理 $b^{2^k}$,这部分时间复杂度是 $O(\log_2 n \cdot tot^3)$ 的。
后面快速幂的部分每次我们都调用预处理好的 $b^{2^k}$,这里的矩阵乘法是 $O(tot^2)$ 的。
你需要使用常数更小的矩阵乘法,这个差异是非常大的。
是的没错,每一种 $m$ 的转移都不一样……
对于这道题目,需要格外注意的是 $+1$ 的系数需要一起转移,即 $arr[0][0] = 1$。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| #include<bits/stdc++.h> using namespace std; #define ll long long const int mod=998244353; const int maxn=50; const int maxm=8; const int maxs=170; const int maxg=60; struct Mat{int a[maxs+5][maxs+5],N,M;}; ll n; int m,K; int id[maxm+5][maxm+5][maxm+5],tot; Mat bs[maxg+5]; inline int Fpow(int x,int y){ int res=1; for(;y;x=1ll*x*x%mod,y>>=1) if(y&1) res=1ll*res*x%mod; return res; } Mat operator*(const Mat &x,const Mat &y){ Mat z; z.N=x.N,z.M=y.M; for(int i=0;i<=z.N;i++) for(int j=0;j<=z.M;j++) z.a[i][j]=0; for(int i=0;i<=x.N;i++) for(int k=0;k<=x.M;k++) for(int j=0;j<=y.M;j++) z.a[i][j]=(z.a[i][j]+1ll*x.a[i][k]*y.a[k][j])%mod; return z; } inline void Solve(){ scanf("%lld",&n); Mat res; res.N=0,res.M=tot; res.a[0][0]=1; for(int i=1;i<=tot;i++) res.a[0][i]=0; for(int i=0;n;n>>=1,i++) if(n&1) res=res*bs[i]; if(m==1) printf("%d\n",res.a[0][id[1][0][0]]); else if(m==2) printf("%d\n",res.a[0][id[0][1][0]]); else printf("%d\n",res.a[0][id[0][0][1]]); } signed main(){ int T; scanf("%d%d%d",&T,&m,&K); if(m==1){ for(int i=0;i<=K;i++) id[i][0][0]=++tot; for(int i=0;i<=K;i++){ int cnt=Fpow(i+1,mod-2); if(i) bs[0].a[id[i-1][0][0]][id[i][0][0]]=1ll*i*cnt%mod; bs[0].a[id[i][0][0]][id[i][0][0]]=cnt; bs[0].a[0][id[i][0][0]]=cnt; } } else if(m==2){ for(int i=0;i<=K;i++) for(int j=0;i+j<=K;j++) id[i][j][0]=++tot; for(int i=0;i<=K;i++){ for(int j=0;i+j<=K;j++){ int cnt=Fpow(i+j+1,mod-2); int tmp=i+j<K; if(i) bs[0].a[id[i-1][j][0]][id[i][j][0]]=1ll*i*cnt%mod; if(j) bs[0].a[id[i+1][j-1+tmp][0]][id[i][j][0]]=1ll*j*cnt%mod; bs[0].a[id[i][j][0]][id[i][j][0]]=cnt; bs[0].a[0][id[i][j][0]]=cnt; } } } else{ for(int i=0;i<=K;i++) for(int j=0;i+j<=K;j++) for(int k=0;i+j+k<=K;k++) id[i][j][k]=++tot; for(int i=0;i<=K;i++){ for(int j=0;i+j<=K;j++){ for(int k=0;i+j+k<=K;k++){ int cnt=Fpow(i+j+k+1,mod-2); int tmp=i+j+k<K; if(i) bs[0].a[id[i-1][j][k]][id[i][j][k]]=1ll*i*cnt%mod; if(j) bs[0].a[id[i+1][j-1][k+tmp]][id[i][j][k]]=1ll*j*cnt%mod; if(k) bs[0].a[id[i][j+1][k-1+tmp]][id[i][j][k]]=1ll*k*cnt%mod; bs[0].a[id[i][j][k]][id[i][j][k]]=cnt; bs[0].a[0][id[i][j][k]]=cnt; } } } } bs[0].a[0][0]=1; bs[0].N=bs[0].M=tot; for(int i=1;i<=maxg;i++) bs[i]=bs[i-1]*bs[i-1]; while(T--) Solve(); return 0; }
|