1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
| #include<bits/stdc++.h>
// Shorthand for the 64-bit type used in modular products.
#define ll long long
// Largest 32-bit signed int (not referenced by the visible code).
#define INF 2147483647
// Prime modulus for both the answer and the string hashes.
#define mod 998244353
// Base of the polynomial string hash.
#define px 11
// Modular product computed in 64 bits to avoid int overflow.
#define mul(a, b) ((ll)(a) * (ll)(b) % mod)
// Reads the next unsigned decimal integer from `in`, skipping every
// non-digit character before it (so a leading '-' is ignored).
// Returns 0 if end-of-file is reached before any digit is seen.
// The character that ends the number is consumed.
// Values too large for int are not handled (signed overflow).
int inp_from(std::FILE *in){
    // int, not char: the value must be able to hold EOF distinctly from any byte.
    int c = std::getc(in);
    while(c != EOF && (c < '0' || c > '9')) c = std::getc(in);
    int value = 0;
    while(c >= '0' && c <= '9'){
        value = value * 10 + (c - '0');
        c = std::getc(in);
    }
    return value;
}

// Reads the next unsigned decimal integer from stdin; see inp_from.
int inp(){ return inp_from(stdin); }
// Input strings, stored 1-indexed (scanf writes from index 1) and NUL-terminated.
char s[1000010];
char l[1000010];
char r[1000010];
// f[i]: number of valid ways to split the prefix s[1..i], reduced mod `mod`.
ll f[1000010];
// sum[]: difference array that spreads f[i] over a range of later f[] entries.
ll sum[1000010];
// Lengths of s, l and r.
int n;
int sl;
int sr;
// pl[x] / pr[x]: common-prefix length of s[x..] with l / r, filled by solve().
int pl[1000010];
int pr[1000010];
// Polynomial prefix hashes of s, l and r; entry 0 is the empty prefix (0).
int hshs[1000010];
int hshl[1000010];
int hshr[1000010];
// fpow[k] = px^k mod `mod`, filled in main().
int fpow[1000010];
// True when the length-sl piece of s starting at position x exists and is
// lexicographically (hence, for equal lengths, numerically) >= l.
// Uses pl[x], the common-prefix length of s[x..] and l.
bool ok1(int x){
    const int last = x + sl - 1;
    if(last <= n){
        const int common = pl[x];
        // The whole piece equals l.
        if(common == sl) return true;
        // Otherwise the first mismatching character decides the order.
        return s[x + common] > l[common + 1];
    }
    return false;
}
// True when the length-sr piece of s starting at position x exists and is
// lexicographically (hence, for equal lengths, numerically) <= r.
// Uses pr[x], the common-prefix length of s[x..] and r.
bool ok2(int x){
    if(x + sr - 1 <= n){
        const int common = pr[x];
        // Equal to r, or smaller at the first mismatching character.
        return common == sr || s[x + common] < r[common + 1];
    }
    return false;
}
int get_hsh(int *hsh, int l, int r){ return (hsh[r] - mul(hsh[l - 1], fpow[r - l + 1]) + mod) % mod; }
// For every start position 1 <= i <= n - len + 1, stores in s[i] the length
// (capped at len) of the longest common prefix of the global string s[i..]
// and the pattern whose prefix hashes are `hsh`. Uses binary search on the
// length with substring-hash comparison, so a hash collision could
// overestimate it. Entries past n - len + 1 are left untouched.
// Note: the parameter `s` is the output array, not the global string.
void solve(int *s, int *hsh, int len){
    const int last_start = n - len + 1;
    for(int start = 1; start <= last_start; ++start){
        // The answer always lies in [lo, hi].
        int lo = 0;
        int hi = len;
        while(lo < hi){
            const int probe = lo + (hi - lo + 1) / 2;
            const bool same = get_hsh(hshs, start, start + probe - 1) == get_hsh(hsh, 1, probe);
            if(same) lo = probe;
            else hi = probe - 1;
        }
        s[start] = lo;
    }
}
void gethash(char *str, int *hsh, int len){ for(int i = 1; i <= len; i++) hsh[i] = ((ll)(hsh[i - 1]) * (ll)(px) + str[i] - '0' + 1) % mod; }
int main(){ fpow[0] = 1; for(int i = 1; i <= 1000000; i++) fpow[i] = mul(fpow[i - 1], px); scanf("%s", s + 1); scanf("%s", l + 1); scanf("%s", r + 1); n = strlen(s + 1); sl = strlen(l + 1); sr = strlen(r + 1); gethash(l, hshl, sl); gethash(r, hshr, sr); gethash(s, hshs, n); f[0] = 1; solve(pl, hshl, sl); solve(pr, hshr, sr); for(int i = 0; i <= n; i++){ if(i){ sum[i] += sum[i - 1]; f[i] += sum[i]; f[i] %= mod; } if(s[i + 1] == '0'){ if(sl == 1 && l[1] == '0'){ f[i + 1] += f[i]; f[i + 1] %= mod; } continue; } if(sl < sr){ sum[sl + i + 1] += f[i]; sum[sl + i + 1] %= mod; sum[sr + i] += mod - f[i]; sum[sr + i] %= mod; } if(sl == sr){ if(ok1(i + 1) && ok2(i + 1)){ f[i + sl] += f[i]; f[i + sl] %= mod; } } else { if(ok1(i + 1)){ f[i + sl] += f[i]; f[i + sl] %= mod; } if(ok2(i + 1)){ f[i + sr] += f[i]; f[i + sr] %= mod; } } } std::cout << f[n] << std::endl; }
|