题解 CF1117D Magic Gems

有趣的矩阵乘法

(为方便,下文中“大号宝石”代指连续的mm个分裂出来的宝石,“小号宝石”代指未分裂的单个宝石)

首先,我们观察这题,考虑DPDP​,设状态fif_i​表示已经取了ii​个单元的方案数的不难推出一个朴素的O(n2)DPO(n^2) DP​方程fi=ijmfj+1f_i = \displaystyle\sum_{i - j \geq m} f_j +1​(可以理解成上一个大号宝石放的位置,最后一个11​即为全部用小号宝石填满的方案)

我们再仔细看看这个式子,加个前缀和,不难优化到O(n)O(n),然而数据范围n1018n \leq 10^{18},这让我们考虑O(logn)O(\log n)级别的算法,我们接下来考虑矩阵乘法优化这个式子。

显然,这个式子跟满足ijmi - j \geq m​jj​有关,但是这些数字的数量是nn​级别的,我们考虑将ijmfj\displaystyle\sum_{i - j \geq m} f_j​变形,变成j=1ifjj=imj<ifj\displaystyle\sum_{j = 1}^{i} f_j - \sum_{j = i - m}^{j < i} f_j​ 这样只要我们维护一下j=1ifj\displaystyle\sum_{j = 1}^{i} f_j​就可以把需要维护的值的数量降到mm​级别。

接下来直接在矩阵的第一行的第j(1jm)j (1 \leq j \leq m) ​ 位放上fijf_{i - j}​,然后第m+1m + 1​位维护j=1ifj\displaystyle\sum_{j = 1}^{i} f_j​,第m+2m + 2​位再弄个11​,瞎构造一通转移矩阵,就可以愉快的套矩阵快速幂板子了,最终复杂度O(m3logn)O(m^3 \log n)​

(转移的矩阵的具体构造建议看代码)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<bits/stdc++.h>

#define ll long long
#define INF 2147483647
#define mod 1000000007

ll inp(){
char c = getchar();
while(c < '0' || c > '9')
c = getchar();
ll sum = 0;
while(c >= '0' && c <= '9'){
sum = sum * 10 + c - '0';
c = getchar();
}
return sum;
}

class Square{
public:
long long num[110][110];
int len;
Square operator *(Square b){
Square ans;
memset(ans.num, 0, sizeof(ans.num));
for(int i = 1;i <= len; i++)
for(int j = 1; j <= len; j++)
for(int k = 1; k <= len; k++){
ans.num[i][j] += num[i][k] * b.num[k][j];
ans.num[i][j] %= mod;
}
ans.len = len;
return ans;
}
};

int main(){
ll n = inp();
int m = inp() - 1;
Square a, b;
a.len = m + 2;
b.len = m + 2;
memset(a.num, 0, sizeof(a.num));
a.num[1][m + 2] = 1;
a.num[1][1] = 1;
a.num[1][m + 1] = 1;
memset(b.num, 0, sizeof(b.num));
for(int i = 1; i <= m; i++)
b.num[i][m + 1] = -1;
b.num[m + 1][m + 1] = 2;
b.num[m + 2][m + 1] = 1;
b.num[m + 2][m + 2] = 1;
for(int i = 2; i <= m; i++)
b.num[i - 1][i] = 1;
for(int i = 1; i <= m; i++)
b.num[i][1] = -1;
b.num[m + 2][1] = b.num[m + 1][1] = 1;
// n -= 2;
while(n){
if(n & 1)
a = a * b;
b = b * b;
n >>= 1;
// printf("%lld\n", a.num[1][1]);
}
// printf("%lld\n", (a * (b * b) * b * b).num[1][1]);
// for(int i = 1; i <= n; i++){
// printf("%lld\n", a.num[1][1]);
// a = a * b;
// }
std::cout << a.num[1][1] << std::endl;
}

QQ

|

Codeforces

|

Luogu

|

Github
本站由 Hexo 驱动,使用 Azurus 作为主题。