这是我参与更文挑战的第 15 天,活动详情查看: 更文挑战
状态压缩型 DP
状态压缩型 DP 的关键就是如何表达 f(x) 函数。这类题目有几个特点:
- x 往往表示一个数组,并且这个数组是原始数组 A[] 的一个子集:即 A[] 数组中的每个元素可以选,也可以不选;
- 直接用哈希或者数组表达 f(x) 都不是特别方便;
- 原始数组 A[] 并不会特别长。
基于这样的一个特点,我们在设计 f(x) 函数的时候,就需要根据以下两个点进行破题:
- A[i] 元素可以选,也可以不选;
- 原始数组 A[] 不会特别长。
选和不选,可以用 0/1 来表示,虽然 x 表示的是一个数组,但是我们可以用一个与原始数组等长(或者更长一些)的二进制整数 y 来表示第 i 个 bit 位:0 表示不选 A[i],1 表示选 A[i]。
这样,就可以通过一个整数 y 表达 x 数组,然后利用数组 dp[y] 表达 f(x) 的状态。下面我们看一下例题。这里用了一个非常简单的整数表示一个数组的信息,所以这种 DP 也被叫作状态压缩 DP。
例子:N 次操作的最大分数和
题目
给你 A[] ,它是一个大小为 2 * n 的正整数数组。你必须对这个数组执行 n 次操作。在第 i 次操作时(操作编号从 1 开始),你需要:
Step 1. 选择两个元素 x 和 y 。
Step 2. 获得分数 i * gcd(x, y) 。
Step 3. 将 x 和 y 从 A[] 中删除。
返回 n 次操作后,请你求解获得的分数和最大为多少。函数 gcd(x, y) 是 x 和 y 的最大公约数。
输入:A = [3,4,6,8]
输出:11
解释:最优操作是:
(1 * gcd(3, 6)) + (2 * gcd(4, 8)) = 3 + 8 = 11
注意:数组最大长度为 14。
分析
当拿到这道题的时候,我们首先进行一下简化,把这个问题转化为一个等价的问题。
虽然题目中给定的是删除操作,我们可以把这个删除之后的元素,放到一个数组 C[] 中,操作步骤如下:
(1 * gcd(3, 6)) + (2 * gcd(4, 8)) = 3 + 8 = 11
也可以认为是:
- 第 1 次添加:<3, 6> = 1 * gcd(3, 6)), C = [3, 6]
- 第 2 次添加:<4, 8> = (2 * gcd(4, 8), C = [3, 6, 4, 8]
- 最终收益 = 11
题目就可以简化成如下这样。
一开始你有一个空数组 C,以及有元素的数组 A[],你需要如下操作:
从 A 数组中选择两个数 x,y,然后将这两个数从 A[] 中删除;
将这两个数放到 C 中;
获得分数 len(C) / 2 * gcd(x, y)。
求如何操作,得到最大分数。
这样处理的好处在于:我们不再需要记录步数信息,只需要看当前数组 C 的大小,就可以得到当前是第几步。即步数 = C 数组大小 / 2。
接下来,我们就基于这个稍做改动的等价题目来分析。需要注意两个信息:
- 最大值
- 数组本身不会太大
我们首先分析第一个信息,求最大值,那么这里我们尝试一下 DP。而第二个信息告诉我们,如果进行暴力搜索,其实状态空间也只有 2N 种。并不算特别大。
下面我们看一下如何运用好题意给出的这两个重要信息。
1. 最后一步
当我们执行到最后一步的时候,数组中肯定只会剩下两个数了。假设这两个数是 <x, y>,那么最后一步,得到的收益就是:
last_step = len(C) + 2 / 2;
last_income = 形成数组 C[] 的最大收益 + last_step * gcd(x,y)。
复制代码
但是,最终余下的两个数 <x, y>,可以是原始数组 A[] 的任意的两个数。所以我们可以用伪代码表示如下:
last_income = 0;
for (int j = 0; j < N; j++): // x = a[j]
for (int k = j + 1; k < N; k++): // y = A[k]
// 数组C[], 再加上最后一步加入的<x,y>,那么长度必然与原始数组A一样长
last_step = len(C) + 2 / 2 即 len(A) / 2
last_income = max((计算C[]数组的收益 + last_step * gcd(A[j], A[k]),
last_income);
复制代码
2. 子问题
研究最后一步之后,可以发现,要递归计算的是数组 C[] 和数组 A[] 的最大收益,那么子问题可以表示如下:
- f(A) 表示数组 A[] 的最大收益;
- f(C) 表示数组 C[] 的最大收益。
我们可以统一用 f(x) 表示最终问题与子问题:
f(x) 表示 x[] 数组的最大收益;
其中,x[] 是原始数组 A[ ]的子序列。
3. 递推关系
我们可以利用伪代码,重新表达一下这个递推关系,代码如下(解析在注释里):
int f(x[]) { // 形成数组x[]的最大收益
ans = 0;
for (int j = 0; j < N; j++):
for (int k = j + 1; k < N; k++):
C[] = x.remove_index_item(j, k)
L = len(x) / 2; // C[]数组加入<x,y>之后形成x数组
ans = max((f(C) + L * gcd(A[j], A[k]), ans);
return ans;
}
复制代码
4. f(x) 的表达
由于 x 数组肯定是 A 数组的一个子集,并且 A 数组并不是特别大。那么我们可以用二进制串来表示 A[i] 元素是否存在于 x 数组这个关系:
- 1 表示 A[i] 存在于 x 数组中;
- 0 表示 A[i] 不存在于 x 数组中。
在这种情况下,我们可以申请一个数组:
int[] dp = new int[1<<(len(A))];
复制代码
然后用 dp[i] 表示 f(x)。其中i这个整数的二进制串表示:A[] 数组的子序列 x[]。
5. 初始条件与边界
首先,当数组为空的时候,肯定是没有什么收益的。所以此时 dp[0] = 0。并且,由于我们总是成对地添加元素,所以当 dp[i] 中的下标 i 里面的 bit 1 的个数为奇数(表示 x[] 数组有奇数个元素),这种情况应该是不可能出现的,不需要进行处理。
6. 计算顺序
当我们使用更改之后的题目进行处理的时候,就可以直接从 dp[0] 开始计算了。
完整代码
得到状态压缩之后,我们可以写出代码如下(解析在注释里):
class Solution{
private int bitCount(int x) {
int ret = 0;
while (x != 0) {
ret += (x & 0x01) == 1 ? 1 : 0;
x >>= 1;
}
return ret;
}
private int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
public int maxScore(int[] A) {
final int N = A == null ? 0 : A.length;
final int total_steps = N >> 1;
// 一共有N个数
// 每个数可以表示存在,或者不存在
// 那么只有两种状态0/1
// 因此,我们可以用二进制位来进行表示
// 由于题目中已经说明n <= 7
// 所以,最多只需要14 bits
// 那么用一个int位,我们就可以表示了
// 所以这里我们申请dp[array_size]
final int array_size = 1 << N;
int[] dp = new int[array_size];
// dp[0] = 0
// 表示当没有任何数的时候,那么收益肯定为0
// 已经设置过了,这里不用再设置
// 那么接下来就是从余下两个数的时候开始
// 往前推导
for (int i = 3; i < array_size; i++) {
// 这里利用GCC内置的计算整的二进制中1的个数的函数
int cnt = bitCount(i);
// 由于每次需要去掉两个数,当i里面的二进制1的数目为
// 奇数的时候,没有必要计算!
if ((cnt & 0x01) == 1) {
continue;
}
// 当前步数
// 即: 当前我是第几步
final int cur_step = cnt >>> 1;
// 那么我们需要从i里面选两个数
for (int j = 0; j < N; j++) {
// 如果i中没有A[j]这个数
if ((i & (1 << j)) == 0)
continue;
for (int k = j + 1; k < N; k++) {
// 如果i中没有A[k]这个数
if ((i & (1 << k)) == 0)
continue;
// 这里我们选择A[j], A[k]
final int g = gcd(A[j], A[k]);
// 得分
final int score = cur_step * g;
// 得到去掉i,j之后的状态
final int mask = (1 << j) | (1 << k);
final int pre_status = i & ~mask;
final int total = dp[pre_status] + score;
// 选择最大值dp[i]
dp[i] = Math.max(dp[i], total);
}
}
}
return dp[array_size - 1];
}
}
复制代码
复杂度分析:当给定 N 个数的时候,一共需要表达 2N 种状态,而每种状态在处理时候,需要遍历 N x N 次,所以时间复杂度为 O(2N x N x N),空间复杂度为 O(2N)。虽然看起来很大,但是题目中已经明确说了数组的长度 <= 14(最多只有 7 对数)。
小结
这道 DP 题目的关键就是 f(x) 的表示。