哈希 + 二分 + DP

首先看到题面,很容易想到一个$DP$,令$f[i]$为划分到$i$为止的方案数。

然后朴素的暴力转移是$O(n^2)$的,非常显然一个状态$i$能够转移到的$j$是一段连续的,进而想到使用前缀和优化。

令$l$的长度为$sl$,$r$的长度为$sr$,那么长度为$len$($sl < len < sr$)的一段数字$s$必定满足$l < s < r$。然后那么我们只需要考虑当前状态$i$的时候$[i, i + sl)$和$[i,i + sr)$这两段区间分别和$l,r$的大小关系。

如果我们我们要比较两个字符串$a$和$b$的大小关系,我们完全先特判掉两个完全一样的情况,然后再求出它们的$lcp$,然后那么$a$和$b$的大小关系就是$a[lcp + 1]$和$b[lcp + 1]$的大小关系了。

$lcp$可以哈希 + 二分解决,接下来的事情就是$DP$的了

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include<bits/stdc++.h>

#define ll long long
#define INF 2147483647
#define mod 998244353
#define px 11
#define mul(a, b) ((ll)(a) * (ll)(b) % mod)

int inp(){
char c = getchar();
while(c < '0' || c > '9')
c = getchar();
int sum = 0;
while(c >= '0' && c <= '9'){
sum = sum * 10 + c - '0';
c = getchar();
}
return sum;
}

char s[1000010], l[1000010], r[1000010];
ll f[1000010], sum[1000010];
int n, sl, sr, pl[1000010], pr[1000010], hshs[1000010], hshl[1000010], hshr[1000010], fpow[1000010];

bool ok1(int x){
if(x + sl - 1 > n)
return false;
if(sl == pl[x])
return true;
return l[pl[x] + 1] < s[x + pl[x]];
}

bool ok2(int x){
if(x + sr - 1 > n)
return false;
if(pr[x] == sr)
return true;
return r[pr[x] + 1] > s[x + pr[x]];
}

int get_hsh(int *hsh, int l, int r){
return (hsh[r] - mul(hsh[l - 1], fpow[r - l + 1]) + mod) % mod;
}

void solve(int *s, int *hsh, int len){ // 求lcp
for(int i = 1; i <= n - len + 1; i++){
int l = 0;
int r = len;
while(l < r){
int mid = (l + r + 1) >> 1;
if(get_hsh(hshs, i, i + mid - 1) == get_hsh(hsh, 1, mid))
l = mid;
else
r = mid - 1;
}
s[i] = l;
}
}

void gethash(char *str, int *hsh, int len){
for(int i = 1; i <= len; i++)
hsh[i] = ((ll)(hsh[i - 1]) * (ll)(px) + str[i] - '0' + 1) % mod;
}

int main(){
fpow[0] = 1;
for(int i = 1; i <= 1000000; i++)
fpow[i] = mul(fpow[i - 1], px);
scanf("%s", s + 1);
scanf("%s", l + 1);
scanf("%s", r + 1);
n = strlen(s + 1);
sl = strlen(l + 1);
sr = strlen(r + 1);
gethash(l, hshl, sl);
gethash(r, hshr, sr);
gethash(s, hshs, n);
f[0] = 1;
solve(pl, hshl, sl);
solve(pr, hshr, sr);
// for(int i = 1; i <= n; i++)
// printf("%d ", pl[i]);
// putchar('\n');
// for(int i = 1; i <= n; i++)
// printf("%d ", pr[i]);
// putchar('\n');
for(int i = 0; i <= n; i++){
if(i){
sum[i] += sum[i - 1];
f[i] += sum[i];
f[i] %= mod;
}
// printf("f[%d] = %d\n", i, f[i]);
if(s[i + 1] == '0'){
if(sl == 1 && l[1] == '0'){
f[i + 1] += f[i];
f[i + 1] %= mod;
}
continue;
}
if(sl < sr){
sum[sl + i + 1] += f[i];
sum[sl + i + 1] %= mod;
sum[sr + i] += mod - f[i];
sum[sr + i] %= mod;
}
if(sl == sr){
if(ok1(i + 1) && ok2(i + 1)){
f[i + sl] += f[i];
f[i + sl] %= mod;
}
} else {
if(ok1(i + 1)){
f[i + sl] += f[i];
f[i + sl] %= mod;
}
if(ok2(i + 1)){
f[i + sr] += f[i];
f[i + sr] %= mod;
}
}
// printf("f[%d] = %d\n", i, f[i]);
}
std::cout << f[n] << std::endl;
}

 评论