有趣的矩阵乘法

(为方便,下文中“大号宝石”代指连续的$m$个分裂出来的宝石,“小号宝石”代指未分裂的单个宝石)

首先,我们观察这题,考虑$DP​$,设状态$f_i​$表示已经取了$i​$个单元的方案数的不难推出一个朴素的$O(n^2) DP​$方程$f_i = \displaystyle\sum_{i - j \geq m} f_j +1​$(可以理解成上一个大号宝石放的位置,最后一个$1​$即为全部用小号宝石填满的方案)

我们再仔细看看这个式子,加个前缀和,不难优化到$O(n)$,然而数据范围$n \leq 10^{18}$,这让我们考虑$O(\log n)$级别的算法,我们接下来考虑矩阵乘法优化这个式子。

显然,这个式子跟满足$i - j \geq m​$的$j​$有关,但是这些数字的数量是$n​$级别的,我们考虑将$\displaystyle\sum_{i - j \geq m} f_j​$变形,变成$\displaystyle\sum_{j = 1}^{i} f_j - \sum_{j = i - m}^{j < i} f_j​$ 这样只要我们维护一下$\displaystyle\sum_{j = 1}^{i} f_j​$就可以把需要维护的值的数量降到$m​$级别。

接下来直接在矩阵的第一行的第$j (1 \leq j \leq m) ​$ 位放上$f_{i - j}​$,然后第$m + 1​$位维护$\displaystyle\sum_{j = 1}^{i} f_j​$,第$m + 2​$位再弄个$1​$,瞎构造一通转移矩阵,就可以愉快的套矩阵快速幂板子了,最终复杂度$O(m^3 \log n)​$。

(转移的矩阵的具体构造建议看代码)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<bits/stdc++.h>

#define ll long long
#define INF 2147483647
#define mod 1000000007

ll inp(){
char c = getchar();
while(c < '0' || c > '9')
c = getchar();
ll sum = 0;
while(c >= '0' && c <= '9'){
sum = sum * 10 + c - '0';
c = getchar();
}
return sum;
}

class Square{
public:
long long num[110][110];
int len;
Square operator *(Square b){
Square ans;
memset(ans.num, 0, sizeof(ans.num));
for(int i = 1;i <= len; i++)
for(int j = 1; j <= len; j++)
for(int k = 1; k <= len; k++){
ans.num[i][j] += num[i][k] * b.num[k][j];
ans.num[i][j] %= mod;
}
ans.len = len;
return ans;
}
};

int main(){
ll n = inp();
int m = inp() - 1;
Square a, b;
a.len = m + 2;
b.len = m + 2;
memset(a.num, 0, sizeof(a.num));
a.num[1][m + 2] = 1;
a.num[1][1] = 1;
a.num[1][m + 1] = 1;
memset(b.num, 0, sizeof(b.num));
for(int i = 1; i <= m; i++)
b.num[i][m + 1] = -1;
b.num[m + 1][m + 1] = 2;
b.num[m + 2][m + 1] = 1;
b.num[m + 2][m + 2] = 1;
for(int i = 2; i <= m; i++)
b.num[i - 1][i] = 1;
for(int i = 1; i <= m; i++)
b.num[i][1] = -1;
b.num[m + 2][1] = b.num[m + 1][1] = 1;
// n -= 2;
while(n){
if(n & 1)
a = a * b;
b = b * b;
n >>= 1;
// printf("%lld\n", a.num[1][1]);
}
// printf("%lld\n", (a * (b * b) * b * b).num[1][1]);
// for(int i = 1; i <= n; i++){
// printf("%lld\n", a.num[1][1]);
// a = a * b;
// }
std::cout << a.num[1][1] << std::endl;
}

 评论