其实我很想吐槽这题

我比赛的时候疯狂$\text{WA on test 16}​$

然后死活找不出来哪儿错了

我觉得我数组开小了

于是我开到$MLE$都没有过掉

最后一看

被一个n = 2的点卡掉了……

才出4题,罚时爆炸,上黄失败,掉分哭唧唧

首先我们看到题,我们可以想到一个朴素的$DP$方法:

我们用$f[i][j]$来表示在前$i$个里面取到和为$j$的方案数

然后我们可以列出$DP$方程,即$f[i][j] = \displaystyle\sum_{u = 1}^{k} f[i - 1][j - d[u]]$

那么显然,答案就是$\displaystyle \sum_{i = 1}^{n * 5}f[\frac{n}{2}][i] $

然而这个朴素的DP是$O(n^2)$的

我们考虑如何优化这个$DP$方程

我们想想看卷积的公式

$c[i] = \displaystyle \sum_{j = 0}^{i} a[j] * b[i - j]$

我们再试着改写一下原来的$DP$方程

我们令$g[u]$为$1$当且仅当$u \in d$,否则$g[u] = 0$

那么式子就可以改写成

$f[i][j] = \displaystyle\sum_{u = 1}^{j} f[i - 1][j - u] * g[u]$

我一看,这下子$f[i]$不就是$f[i - 1]$和$g$这两个多项式的卷积了吗

我们再考虑$f[0]$,显然$f[0] = {1, 0, 0, 0… }$

然后我们发现任何多项式乘上$f[0]$都等于它本身

于是我们发现最终的答案就是$g$这个多项式的$\frac{n}{2}$次幂

然后我们就可以愉快的用多项式快速幂解决这个问题了

注意多项式的长度要动态开,否则会T飞~

这里我用的是$NTT$,毕竟模数是$998244353$,取模比起$FFT$方便不少…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include<bits/stdc++.h>
// o^r^z w^x^w

#define ll long long
#define INF 2147483647
#define mod 998244353

int pi = acos(-1);
const int G = 3;

inline int inp(){
char c = getchar();
while(c < '0' || c > '9')
c = getchar();
int sum = 0;
while(c >= '0' && c <= '9'){
sum = sum * 10 + c - '0';
c = getchar();
}
return sum;
}

ll powmod(ll a, int b){
ll sum = 1;
while(b){
if(b & 1){
sum *= a;
sum %= mod;
}
a *= a;
a %= mod;
b >>= 1;
}
return sum;
}

int g[8000010], num[8000010];
int len, r[8000010];
int a[8000010], b[8000010], c[8000010];
int wn[8000010];
int invlim;
#define add(a, b) ((a) + (b) >= mod ? (a) + (b) - mod : (a) + (b))
#define mine(a, b) ((a) < (b) ? (a) - (b) + mod : (a) - (b))
#define mul(a, b) ((ll)(a) * (ll)(b) % mod)

inline void ntt(int *a, int f){
for(int i = 0; i < len; i++)
if(i < r[i])
std::swap(a[i], a[r[i]]);
for(int i = 1; i < len; i <<= 1){
wn[0] = 1;
wn[1] = powmod(G, (mod - 1) / (i << 1));
for(int j = 2; j < i; j++)
wn[j] = mul(wn[j - 1], wn[1]);
for(int j = 0; j < len; j += i << 1){
int *L = a + j;
int *R = L + i;
for(int k = 0; k < i; k++){
const int t = mul(wn[k], R[k]);
R[k] = mine(L[k], t);
L[k] = add(L[k], t);
}
}
}
if(f == -1){
std::reverse(a + 1, a + len);
for(int i = 0; i < len; i++)
a[i] = mul(a[i], invlim);
}
}

void init(int n){
len = 1;
int lg = 0;
while(len <= n)
len <<= 1, lg++;
for(int i = 0; i < len; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lg - 1));
invlim = powmod(len, mod - 2);
}

int main(){
int n = inp() >> 1;
int k = inp();
memset(g, 0, sizeof(g));
for(int i = 1; i <= k; i++){
int x = inp();
g[x] = 1;
}
int pownum = n - 1;
memcpy(num, g, sizeof(g));
init(20);
while(pownum){
// printf("%d\n", pownum);
if(pownum & 1){
int maxw = 0;
for(int i = 1; i <= 5000000; i++)
if(g[i] > 0)
maxw = i;
init(maxw << 2);
ntt(g, 1);
ntt(num, 1);
for(int i = 0; i <= len; i++)
num[i] = mul(num[i], g[i]);
ntt(g, -1);
ntt(num, -1);
}
int maxw = 0;
for(int i = 1; i <= 5000000; i++)
if(g[i] > 0)
maxw = i;
init(maxw << 2);
ntt(g, 1);
for(int i = 0; i <= len; i++)
g[i] = mul(g[i], g[i]);
ntt(g, -1);
pownum >>= 1;
}
int ans = 0;
for(int i = 0; i <= len; i++)
ans = add(ans, mul(num[i], num[i]));
printf("%d\n", ans);
}

 评论