有趣的矩阵乘法
(为方便,下文中“大号宝石”代指连续的m m m 个分裂出来的宝石,“小号宝石”代指未分裂的单个宝石)
首先,我们观察这题,考虑D P DP D P ,设状态f i f_i f i 表示已经取了i i i 个单元的方案数的不难推出一个朴素的O ( n 2 ) D P O(n^2) DP O ( n 2 ) D P 方程f i = ∑ i − j ≥ m f j + 1 f_i = \displaystyle\sum_{i - j \geq m} f_j +1 f i = i − j ≥ m ∑ f j + 1 (可以理解成上一个大号宝石放的位置,最后一个1 1 1 即为全部用小号宝石填满的方案)
我们再仔细看看这个式子,加个前缀和,不难优化到O ( n ) O(n) O ( n ) ,然而数据范围n ≤ 1 0 18 n \leq 10^{18} n ≤ 1 0 1 8 ,这让我们考虑O ( log n ) O(\log n) O ( log n ) 级别的算法,我们接下来考虑矩阵乘法优化这个式子。
显然,这个式子跟满足i − j ≥ m i - j \geq m i − j ≥ m 的j j j 有关,但是这些数字的数量是n n n 级别的,我们考虑将∑ i − j ≥ m f j \displaystyle\sum_{i - j \geq m} f_j i − j ≥ m ∑ f j 变形,变成∑ j = 1 i f j − ∑ j = i − m j < i f j \displaystyle\sum_{j = 1}^{i} f_j - \sum_{j = i - m}^{j < i} f_j j = 1 ∑ i f j − j = i − m ∑ j < i f j 这样只要我们维护一下∑ j = 1 i f j \displaystyle\sum_{j = 1}^{i} f_j j = 1 ∑ i f j 就可以把需要维护的值的数量降到m m m 级别。
接下来直接在矩阵的第一行的第j ( 1 ≤ j ≤ m ) j (1 \leq j \leq m) j ( 1 ≤ j ≤ m ) 位放上f i − j f_{i - j} f i − j ,然后第m + 1 m + 1 m + 1 位维护∑ j = 1 i f j \displaystyle\sum_{j = 1}^{i} f_j j = 1 ∑ i f j ,第m + 2 m + 2 m + 2 位再弄个1 1 1 ,瞎构造一通转移矩阵,就可以愉快的套矩阵快速幂板子了,最终复杂度O ( m 3 log n ) O(m^3 \log n) O ( m 3 log n ) 。
(转移的矩阵的具体构造建议看代码)
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 #include <bits/stdc++.h> #define ll long long #define INF 2147483647 #define mod 1000000007 ll inp () { char c = getchar(); while (c < '0' || c > '9' ) c = getchar(); ll sum = 0 ; while (c >= '0' && c <= '9' ){ sum = sum * 10 + c - '0' ; c = getchar(); } return sum; } class Square { public : long long num[110 ][110 ]; int len; Square operator *(Square b){ Square ans; memset (ans.num, 0 , sizeof (ans.num)); for (int i = 1 ;i <= len; i++) for (int j = 1 ; j <= len; j++) for (int k = 1 ; k <= len; k++){ ans.num[i][j] += num[i][k] * b.num[k][j]; ans.num[i][j] %= mod; } ans.len = len; return ans; } }; int main () { ll n = inp(); int m = inp() - 1 ; Square a, b; a.len = m + 2 ; b.len = m + 2 ; memset (a.num, 0 , sizeof (a.num)); a.num[1 ][m + 2 ] = 1 ; a.num[1 ][1 ] = 1 ; a.num[1 ][m + 1 ] = 1 ; memset (b.num, 0 , sizeof (b.num)); for (int i = 1 ; i <= m; i++) b.num[i][m + 1 ] = -1 ; b.num[m + 1 ][m + 1 ] = 2 ; b.num[m + 2 ][m + 1 ] = 1 ; b.num[m + 2 ][m + 2 ] = 1 ; for (int i = 2 ; i <= m; i++) b.num[i - 1 ][i] = 1 ; for (int i = 1 ; i <= m; i++) b.num[i][1 ] = -1 ; b.num[m + 2 ][1 ] = b.num[m + 1 ][1 ] = 1 ; while (n){ if (n & 1 ) a = a * b; b = b * b; n >>= 1 ; } std ::cout << a.num[1 ][1 ] << std ::endl ; }