【ARC058 F】Iroha Loves Strings
原題連結:https://atcoder.jp/contests/arc058/tasks/arc058_d
前言
網路上的題解感覺都寫得不是很完整,而 Atcoder 自己的官解又是日文,但我個人又覺得這題官解過份精湛所以乾脆打算久違的來發一篇文加深印象(?
天真做法
首先看完題目有一個天真做法如下:
dp[i][j]:=前 i 個字串組成長度 j 的最小字典序字串
這樣光每一排就得存下至多 O(K2) 個字元,每次轉移也可能要比較 K 次長度至多 K 的字串,這樣無論是時間還是空間複雜度都是 O(NK2),顯然會 TLE + MLE。
我認為這題官解厲害的地方是他不打算直接改狀態,而是直接從優化下手,一般而言 dp 優化都會想從轉移速度、或是壓狀態開始優化,但直接在這題想要這樣做的話就會發現很難行得通。
巧妙的點就在於,官解打算先從壓空間開始。
空間優化
有一個很直觀的空間壓法就是在每個 dp[i][j] 存下轉移來源和構造方法,這樣的話顯然能夠乾淨的將空間壓到 O(NK),不過這個方法的轉移會變得很難優化,我自己在 virtual 中有嘗試這個思路但失敗了,甚至後來發現可能要用 rolling hash + 倍增才有辦法讓轉移變成 O(logN+logK),又難寫又胖。
官解從「刪除無用答案」當成出發點,什麼叫無用答案呢?一個狀態 dp[i][j] 存的答案無用代表他符合以下兩點其中之一:
1. 無法利用 si+1∼sN 湊出長度為 K−j 的字串。
2. 存在另一個有值、滿足 1. 的 dp[i][j′],使得「他的字典序嚴格比 dp[i][j] 小」
不滿足 1. 就無用應該很顯然,但對於 2.,什麼叫嚴格小呢?舉個例子,字串「aba」就不嚴格小於「ababa」,因為有可能「aba」後面能接的字串太大了導致「ababa」反而勝出;而字串「aaa」就嚴格小於「ababa」,因為無論在「aaa」後面接什麼,他永遠都會打贏「ababa」。
如此一般過濾掉不需要的字串後,神奇的地方就來了,可以觀察到「對於任兩個滿足 1. 的 dp[i][j],若他們互相不為對方的前綴,那麼一定會有一個人打輸被刪掉」。接著就可以得出結論:「所有有值且活著的 dp[i][j] 都是最長那個 dp[i][j] 的前綴」
因此要怎麼壓空間呢?只要在每個 i 存下最長活著的那個 dp[i][j],我們令其為 CSi,新 dp[i][j] 們的值改存一般的 bool 就好了,即代表 CSi 長度 j 的前綴是否可以被 s1∼si 湊出來。
經過一連串觀察後我們把空間壓到 O(NK),但更巧妙的是這些觀察可以更進一步的幫助我們優化轉移速度。
轉移優化
思考一下轉移,我們可以用推的 DP 去想,當 dp[i] 要推到 dp[i+1] 時,要嘛可以直接把原有的字串繼承過去,要嘛就是加一個 si+1 在後面轉移過去。因此如果暴力轉移過去的話,會在 dp[i+1] 內得到兩種字串:
1. CSi 的前綴
2. CSi 的前綴加一個 si+1
根據前面的觀察,我們得從這些字串當中選出一群嚴格小的字串,並從中挑選出最長的來當成 CSi+1。也就是說,我們需要比較這些字串的大小關係!
樸素比較肯定會吃 TLE,用 rolling hash 配二分搜比較需要帶 log 又胖,因此官解採用了一個精湛的做法,也就是利用 Z-value 來達成 O(1) 比較。
考慮對 si+1$CSi 這個字串做 Z-value,仔細畫點比對關係圖就會發現經過一些判斷就可以 O(1) 獲得比較結果,無論是嚴格大於小於,還是前綴關係都可以(這裡我懶就不贅述XD)。
透過比較我們就可以找到 CSi+1,最後只需要重新篩一遍哪些字串可以活著把他們留下來就好了(活著的判斷同樣可以透過前面的 Z-value O(1) 得到)。
因此每次轉移至多只需要對 O(K) 個字串做 O(K) 次的比較,每次都是 O(1)。我們就成功將時間複雜度也壓到 O(NK) 了!
實作細節
這題的實作可能會有點煩躁,因為有不少 case 得判,我個人的做法是先寫好比較的黑盒子,讓黑盒子在嚴格小於或大於的時候回傳 ±1,前綴關係的時候回傳 0。這樣在寫 dp 的時候也會比較乾淨。
先找出 CSi 再篩合法字串也是我個人為了清晰思路搞出來的做法,實際上是可以邊找 CSi 邊篩的。
然後簡單提一下,其實我的 code 並沒有真的把所有違反「無用答案 1.」的答案刪掉,其實也就是會多留著一些 CSi 的前綴。不刪的原因是麻煩,而可行的理由當然是「多轉移不影響答案」。
最後提點寫這題遇到的問題,這題的想法雖然精妙但寫起來異常的容易寫錯Orz,我自己是 WA 了很久才 AC。而且這題的測資也很弱,我第一次 AC 的時候其實是假解,當下興致沖沖的來寫題解寫到一半就發現有地方怪怪的,不太清楚是官解漏寫還是只是單純寫得有點簡略(我不會日文),總之我上面有講詳細就不講是哪裡掉 case 了XD
附上 code:
#include<bits/stdc++.h> | |
using namespace std; | |
typedef long long ll; | |
typedef pair<int, int> pii; | |
typedef pair<ll, ll> pll; | |
#define X first | |
#define Y second | |
#define SZ(a) ((int)a.size()) | |
#define ALL(v) v.begin(), v.end() | |
#define pb push_back | |
const int MAXK = 10005; | |
bitset<MAXK> can[2005], dp[2005]; | |
string arr[2005], cs[2005]; | |
int z[30005]; | |
void make_z(const string &s) { | |
int l = 0, r = 0; | |
for(int i = 1;i < SZ(s);i++){ | |
for(z[i] = max(0, min(r - i + 1, z[i - l])); | |
i + z[i] < SZ(s) && s[i + z[i]] == s[z[i]];z[i]++); | |
if(i + z[i] - 1 > r)l = i, r = i + z[i] - 1; | |
} | |
} | |
// prefix[p] vs prefix[q] + arr[j] | |
int cmp(int i, int p, int q, int j) { | |
if (p <= q) | |
return 0; | |
int lcp = z[SZ(arr[j]) + 1 + q]; | |
if (lcp == SZ(arr[j]) || q + lcp >= p) | |
return 0; | |
if (arr[j][lcp] < cs[i][q + lcp]) | |
return 1; | |
return -1; | |
} | |
// prefix[p] + arr[j] vs prefix[q] + arr[j] | |
int cmp2(int i, int p, int q, int j) { | |
if (p > q) | |
return -cmp2(i, q, p, j); | |
int lcp = z[SZ(arr[j]) + 1 + p]; | |
if (lcp < q - p) { | |
if (lcp == SZ(arr[j])) | |
return 0; | |
if (cs[i][p + lcp] < arr[j][lcp]) | |
return 1; | |
return -1; | |
} | |
lcp = z[q - p]; | |
if (q - p + lcp == SZ(arr[j])) | |
return 0; | |
if (arr[j][lcp] < arr[j][q - p + lcp]) | |
return 1; | |
return -1; | |
} | |
int main() { | |
ios::sync_with_stdio(0), cin.tie(0); | |
int n, k; | |
cin >> n >> k; | |
for (int i = 1; i <= n; ++i) | |
cin >> arr[i]; | |
can[n + 1][0] = 1; | |
for (int i = n; i >= 1; --i) | |
can[i] = can[i + 1] | (can[i + 1] << SZ(arr[i])); | |
dp[0][0] = 1; | |
for (int i = 1; i <= n; ++i) { | |
int nxt = -1; | |
make_z(arr[i] + "$" + cs[i - 1]); | |
z[SZ(arr[i]) + SZ(cs[i - 1]) + 1] = 0; | |
// decide cs[i] | |
for (int j = 0; j + SZ(arr[i]) <= k; ++j) | |
if (dp[i - 1][j] && can[i + 1][k - SZ(arr[i]) - j]) { | |
if (nxt == -1) { | |
int res = cmp(i - 1, SZ(cs[i - 1]), j, i); | |
if (res == 1 || (res == 0 && j + SZ(arr[i]) > SZ(cs[i - 1]))) | |
nxt = j; | |
} | |
else if (cmp2(i - 1, nxt, j, i) >= 0) | |
nxt = j; | |
} | |
if (nxt == -1) {// s[i] is not a suffix of cs[i] | |
cs[i] = cs[i - 1]; | |
// valid "inherit" transition | |
dp[i] = dp[i - 1]; | |
// valid "append" transition | |
for (int j = 0; j + SZ(arr[i]) < SZ(cs[i]); ++j) | |
if (dp[i - 1][j] && can[i + 1][k - SZ(arr[i]) - j] && cmp(i - 1, SZ(cs[i]), j, i) == 0) | |
dp[i][j + SZ(arr[i])] = 1; | |
} | |
else { | |
cs[i] = cs[i - 1].substr(0, nxt) + arr[i]; | |
// valid "inherit" transition | |
for (int j = 0; j <= SZ(cs[i - 1]); ++j) | |
if (dp[i - 1][j] && cmp(i - 1, j, nxt, i) == 0) | |
dp[i][j] = 1; | |
// valid "append" transition | |
dp[i][SZ(cs[i])] = 1; | |
for (int j = 0; j + SZ(arr[i]) < SZ(cs[i]); ++j) | |
if (dp[i - 1][j] && can[i + 1][k - SZ(arr[i]) - j] && cmp2(i - 1, j, nxt, i) == 0) | |
dp[i][j + SZ(arr[i])] = 1; | |
} | |
} | |
cout << cs[n] << "\n"; | |
} |
留言
張貼留言