查看原题请戳这里
导入 我们先来看一个看似正确的分组背包的方法:
我们将每一列拆分为n个物品,第$m$行第$k$个物品的价值是$\sum_{i=k}^na[i][m]$,其代价为$\sum_{i=k}^n[pd[i][m]=1]$。
简单的说,第$m$行第$k$个物品的价值是第$m$行$k-n$个砖块价值的总和,代价为打完这些砖块需要的子弹数。
然后,我们就以此来跑一边分组背包,于是就愉快的暴0了。
为什么呢?
因为如果当前一个标记为$Y$的砖块在最下方,但是我们手中并没有子弹了,那么虽然这个砖块在打前和打后我们拥有的子弹数不变,即代价为$0$,但我们却已经没有办法去打这个砖块了。
正解 预处理 我们可以发现,如果我们当前有$Y$在最下方,而我们最新打的一个砖块的标记为$N$,那么我们完全可以先不打$N$,而是先打$Y$,然后再用新获得的子弹去打那个$N$。
由于当某个标记为$Y$的砖块在最下方时,直接去打掉这个砖块肯定是最优的,所以我们可以贪心地把所有的Y都压在一起。更确切的,我们是把这些标记为$Y$的砖块压到了这些砖块下方的那个砖块。根据引入中提到的那个问题,由于我们打完$N$以后可能恰好用完了所有的子弹,所以我们用$v[i][j][0]$表示第$i$列用$j$发子弹且最后一发子弹打到了$N$上能获得的价值,用$v[i][j][1]$表示第$i$列用$j$发子弹且最后一发子弹打到了$Y$上时获得的价值。
状态设计 我们用$f[i][j][0]$表示前$i$行用$j$发子弹且最后一发子弹打到了标记为$N$的砖块能获得的最大价值,$f[i][j][1]$表示前$i$行用$j$发子弹且最后一发子弹打到了标记为$Y$的砖块能获得的最大价值。
状态转移 先贴一波代码:
1 2 3 4 5 6 7 8 for (int i = 1 ; i <= m; i++) for (int j = 0 ; j <= k; j++) for (int l = 0 ; l <= min(n,j); l++) { f[i][j][1 ] = max(f[i][j][1 ],f[i - 1 ][j - l][1 ] + v[i][l][1 ]); if (l) f[i][j][0 ] = max(f[i][j][0 ],f[i - 1 ][j - l][1 ] + v[i][l][0 ]); if (j > l) f[i][j][0 ] = max(f[i][j][0 ],f[i - 1 ][j - l][0 ] + v[i][l][1 ]); }
其中$i$是枚举到了前$i$列,$j$是前$i$列共用了$j$发子弹,$l$是第$j$列用了$l$发子弹。
1 f[i][j][1 ] = max(f[i][j][1 ],f[i - 1 ][j - l][1 ] + v[i][l][1 ]);
这个转移是说我从$1$到$j-1$列借一发子弹(从最后一发子弹达到标记为$Y$的砖块进行转移,这样才能借到剩余的子弹),先用原本分配给这一列的$l$枚子弹打完所以能打的$N$,然后再用借来的子弹把所以压缩到这个$N$上的$Y$打掉(特殊的,如果这个$N$后面没有$Y$,那我就不打,这样无论如何最终我都会剩余一颗子弹没有用)。
1 if (l) f[i][j][0 ] = max(f[i][j][0 ],f[i - 1 ][j - l][1 ] + v[i][l][0 ]);
这个转移是说如果我分配给了第$j$列了子弹(若$l=0$,则我并没有消耗子弹去打第$j$列的砖块,那么这个转移没有意义),我如果从$1$到$i-1$列借子弹能获得的最大价值。
1 if (j > l) f[i][j][0 ] = max(f[i][j][0 ],f[i - 1 ][j - l][0 ] + v[i][l][1 ]);
这个转移是说如果我让第$1$到$i-1$列消耗了一定量的子弹,且不从其中某列借子弹,第$1$到$i-1$列能够获得的最大价值。
注:在前两段中所说的消耗子弹
是指打完某些砖块后总子弹数变少,只打标记为$Y$的砖块不算消耗子弹。
代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 #include <iostream> #include <cstring> #include <cstdio> #include <cmath> #include <algorithm> #define ll long long #define INF 0x7fffffff #define re register using namespace std ;int read () { register int x = 0 ,f = 1 ;register char ch; ch = getchar(); while (ch > '9' || ch < '0' ){if (ch == '-' ) f = -f;ch = getchar();} while (ch <= '9' && ch >= '0' ){x = x * 10 + ch - 48 ;ch = getchar();} return x * f; } int n,m,k,cnt,a[205 ][205 ],b[205 ][205 ],v[205 ][205 ][2 ],f[205 ][205 ][2 ];char c;int main () { n = read(); m = read();k = read(); for (int i = 1 ; i <= n; i++) for (int j = 1 ; j <= m; j++) { cin >> a[i][j] >> c; if (c == 'Y' ) b[i][j] = 1 ; } for (int i = 1 ; i <= m; i++) { cnt = 0 ; for (int j = n; j >= 1 ; j--) { if (b[j][i]) v[i][cnt][1 ] += a[j][i]; else cnt++,v[i][cnt][1 ] = v[i][cnt - 1 ][1 ] + a[j][i], v[i][cnt][0 ] = v[i][cnt - 1 ][1 ] + a[j][i]; } } for (int i = 1 ; i <= m; i++) for (int j = 0 ; j <= k; j++) for (int l = 0 ; l <= min(n,j); l++) { f[i][j][1 ] = max(f[i][j][1 ],f[i - 1 ][j - l][1 ] + v[i][l][1 ]); if (l) f[i][j][0 ] = max(f[i][j][0 ],f[i - 1 ][j - l][1 ] + v[i][l][0 ]); if (j > l) f[i][j][0 ] = max(f[i][j][0 ],f[i - 1 ][j - l][0 ] + v[i][l][1 ]); } printf ("%d\n" ,f[m][k][0 ]); return 0 ; }