6/25, 动态规划,区间类DP
求一段区间的解 min/max/count
相比划分类 DP ,区间类 DP 为连续相连的 subproblem ,中间不留空,更有 divide & conquer 的味道。
转移方程通过区间更新
从大到小的更新
Matrix-chain multiplication (算法导论)
给定矩阵向量 [A1, A2, A3 .. An]
矩阵乘法有结合律,所以任意的 parenthesization 结果一样
Dimension (a x b) 乘 (b x c) 得到 (a x c) ,总计算量为 a x b x c
可能的括号加法为 Catalan 数,O(2^n),因而搜索不合适。
让 dp[i][j] 代表(i , j) 区间内最优的括号顺序运算次数
符合 optimal substructure,反证法
A = rows * cols,假如从 k 分开左右 i <= k < j ,如下 k = 5 时:
[A1, A2, A3, A4, A5 || A6, A7,A8,A9]
左子问题为 A1.rows x A5.cols
右子问题为 A6.rows x A9.cols
其中 A5.cols = A6.rows
其总花费为 dp[1,5] + dp[6,9] + A1.rows * A5.cols * A9.cols
至此,对于任意 size (i , j) 的向量区间,我们都可以遍历所有合理 k 的切分点,实现记忆化的 divide & conquer,当前区间的最优解一定由其最优子区间拼接而成。
子问题图如下。其实就是一个 n x n 的矩阵对角线,代表所有的子区间。
Palindrome Partitioning II
上一题的求所有区间最优解进行拼接的思路和 optimal substructure 结构和这题非常像,再贴一遍,感受一下。
不过 Matrix Chain Multiplication 要比这个复杂,时间复杂度为 O(n^3). 毕竟每个切点上会生成两个 subproblems.
先贴一段O(n^3)的代码. TLE
犯了错, int tmp = s[i]==s[j]&&dp[i+1][j-1]==0 ? 0:2;
"cabababcbc"在‘c’==‘c’ “abababcb”只需要一个cut的时候,认为"cabababcbc"只要一个cut
忽略了palindrome的连续性, s1一个cut, 并不代表c+s1+c是一个cut
public class Solution {
public int minCut(String str) {
char[] s = str.toCharArray();
int[][] dp = new int[s.length+1][s.length+1];
for(int i=0;i<=s.length;i++){
dp[i][i] = 0;
if(i+1<s.length && s[i]==s[i+1] )
dp[i][i+1]=0;
else if(i+1<s.length)
dp[i][i+1]=1;
}
for(int l=3;l<=s.length;l++){
for(int i=0;i+l-1<s.length;i++){
int j = i+l-1;
int tmp = s[i]==s[j]&&dp[i+1][j-1]==0 ? 0:2;
/*
犯了错, int tmp = s[i]==s[j]&&dp[i+1][j-1]==0 ? 0:2;
"cabababcbc"在‘c’==‘c’ “abababcb”只需要一个cut的时候,认为"cabababcbc"只要一个cut
忽略了palindrome的连续性
*/
dp[i][j] = tmp + dp[i+1][j-1];
//if(i==0&&j==s.length-1) System.out.println("init: " + tmp + " " + dp[i+1][j-1] +" "+ dp[i][j]);
for(int k=i;k<j;k++){
//if(i==0&&j==s.length-1) System.out.println(k + " " + dp[i][k] + " " + dp[k+1][j] + " " + (dp[i][k]+dp[k+1][j]+1));
dp[i][j] = Math.min(dp[i][j], dp[i][k]+dp[k+1][j]+1);
}
}
}
return dp[0][s.length-1];
}
}
O(n^2), O(n) space 的代码
public class Solution {
public int minCut(String s) {
if(s == null || s.length() <= 1) return 0;
int len = s.length();
boolean[][] isPalindrome = new boolean[len][len];
int[] dp = new int[len];
for(int i = 0; i < len; i++){
dp[i] = i;
for(int j = 0; j <= i; j++){
if(s.charAt(i) == s.charAt(j) && (i - j < 2 || isPalindrome[j + 1][i - 1])){
isPalindrome[i][j] = isPalindrome[j][i] = true;
if(j == 0){
dp[i] = 0;
} else {
dp[i] = Math.min(dp[i], dp[j - 1] + 1);
}
}
}
}
return dp[len - 1];
}
}
Stone Game
著名的区间类 DP 入门题 -- 石子归并
以数组【3,4,5,6】 为例,进行归并的 subproblem graph 如下,path 上的数字代表每一步的 cost.
我们一定可以得到一个 height balanced complete tree,因为每步都只归并两堆石子,只是每步的 branching factor 不同;
所有 subproblem 的叶节点 cost 一致,为所有石子的总和。
这种画法的结构是对的,但是并不合理,因为没有体现出“overlap subproblems”,每个子问题看起来都像独立问题一样。
这样就明显多了,而且和前面的 “区间划分DP” 联系紧密。
自己用记忆化搜索写的第一版,比较粗糙~
每次归并的 cost = 归并两个区间的最优 cost + 两个区间的区间和
因此区间最优用 dp[][] 记忆化搜素,区间和 sum[][] 可以 O(n ^ 2) 时间预处理。
O(n^2) preprocess + O(n^2) number of intervals * O(n) number of candidate cuts = O(n^2) + O(n^3)
可以看到,记忆化搜索中,dp[][] 每一个位置只会被遍历一次而且不会再生成新的 subproblems,其时间复杂度和 bottom-up 的迭代循环是一样的。
public class Solution {
/**
* @param A an integer array
* @return an integer
*/
public int stoneGame(int[] A) {
// Write your code here
if(A == null || A.length == 0) return 0;
int n = A.length;
// Minimum cost to merge interval dp[i][j]
int[][] dp = new int[n][n];
int[][] sum = new int[n][n];
// Pre-process interval sum
for(int i = 0; i < n; i++){
for(int j = i; j >= 0; j--){
if(j == i) sum[i][j] = A[i];
else sum[i][j] = sum[j][i] = A[j] + sum[j + 1][i];
}
}
return memoizedSearch(0, n - 1, A, dp, sum);
}
private int memoizedSearch(int start, int end, int[] A, int[][] dp, int[][] sum){
if(start > end) return 0;
if(start == end) return 0;
if(start + 1 == end) return A[start] + A[end];
if(dp[start][end] != 0) return dp[start][end];
int min = Integer.MAX_VALUE;
for(int i = start; i < end; i++){
int cost = memoizedSearch(start, i, A, dp, sum) + memoizedSearch(i + 1, end, A, dp, sum) + sum[start][i] + sum[i + 1][end];
min = Math.min(min, cost);
}
dp[start][end] = min;
return min;
}
}
对于 interval sum ,根据搜索结构可以做一个显而易见的优化,因为每次 split 的 start, pivot, end 我们都知道,而且合并(start, end) 区间的两堆石子,最终的区间和一定为 (start, end) 的区间和,用一维的 prefix sum 数组就可以了。
用 prefix sum 数组要记得初始化时候的 int[n + 1] zero padding,还有取值时候对应的 sum[end + 1] - sum[start + 1 - 1] offset.
public class Solution {
/**
* @param A an integer array
* @return an integer
*/
public int stoneGame(int[] A) {
// Write your code here
if(A == null || A.length == 0) return 0;
int n = A.length;
// Minimum cost to merge interval dp[i][j]
int[][] dp = new int[n][n];
int[] sum = new int[n + 1];
// Pre-process interval sum
for(int i = 0; i < n; i++){
sum[i + 1] = sum[i] + A[i];
}
return memoizedSearch(0, n - 1, A, dp, sum);
}
private int memoizedSearch(int start, int end, int[] A, int[][] dp, int[] sum){
if(start > end) return 0;
if(start == end) return 0;
if(start + 1 == end) return A[start] + A[end];
if(dp[start][end] != 0) return dp[start][end];
int min = Integer.MAX_VALUE;
for(int i = start; i < end; i++){
int cost = memoizedSearch(start, i, A, dp, sum) + memoizedSearch(i + 1, end, A, dp, sum) + sum[end + 1] - sum[start];
min = Math.min(min, cost);
}
dp[start][end] = min;
return min;
}
}
Burst Balloons
这题和石子归并很像,更像 Matrix Chain Multiplication. 都是区间类 DP,而且原数组会随着操作逐渐减小,动态变化。+
然而就算是动态变化的数组,变化的也并不是状态,而只是子状态的范围,记忆化搜索中的 (start, end).
所以这题的难点在于,如何在动态变化的数组中,依然正确定义并计算 subproblem.
问题一:边界气球
- 考虑到计算方式为相邻气球乘积,可以两边放上 1 来做 padding,不会影响最后结果的正确性。
问题二:子问题返回后,如何处理相邻气球?
- 在stone game中,最后融合两个区间要靠区间和;
- 在busrt balloon中,两个区间返回时已经都被爆掉了,融合区间靠的是两个区间最外面相邻的气球。(因此 padding 才很重要)
- 正如 Matrix Chain Multiplication 中,左右区间相乘结束返回时,最后融合那步的 cost = A(start).rows * A(k).cols * A(end).cols
在这里用类似Matrix Chain Multiplication中枚举的方式,长度在最外层枚举,right = left+k; dp(left,right)=dp(left, i)+dp(i+1, right) + nums[left]*nums[i]*nums[right].
public class Solution {
public int maxCoins(int[] iNums) {
int[] nums = new int[iNums.length + 2];
int n = 1;
for (int x : iNums) if (x > 0) nums[n++] = x;
nums[0] = nums[n++] = 1;
int[][] dp = new int[n][n];
for(int k=2;k<n;k++)
for(int left=0;left<n-k;left++){
int right = left+k;
for(int i=left+1;i<right;i++)
dp[left][right] = Math.max(dp[left][right], dp[left][i] + dp[i][right] + nums[left]*nums[i]*nums[right]);
}
return dp[0][n-1];
}
}
public class Solution {
public int maxCoins(int[] nums) {
if(nums == null || nums.length == 0) return 0;
int n = nums.length;
int[] arr = new int[n + 2];
arr[0] = 1;
arr[n + 1] = 1;
for(int i = 0; i < n; i++){
arr[i + 1] = nums[i];
}
int[][] dp = new int[n + 2][n + 2];
return memoizedSearch(1, n, arr, dp);
}
private int memoizedSearch(int start, int end, int[] arr, int[][] dp){
if(dp[start][end] != 0) return dp[start][end];
int max = 0;
for(int i = start; i <= end; i++){
int cur = arr[start - 1] * arr[i] * arr[end + 1];
int left = memoizedSearch(start, i - 1, arr, dp);
int right = memoizedSearch(i + 1, end, arr, dp);
max = Math.max(max, cur + left + right);
}
dp[start][end] = max;
return max;
}
}
这道题开始自己做思路是对的, 递归搜索左右两边subproblem, 问题很像matrix chain multiplication, 因为数组是动态变化的, 所以融合时要考虑两边start-1和end+1, 而不是i-1, i+1; 最后通过记忆化搜索优化.
注意start==end时,不能直接返回num[start], 还是需要相乘两边的值
31/70后TLE
public int maxCoins(int[] nums) {
if(nums==null || nums.length==0) return 0;
int max = dc(nums, 0, nums.length-1);
return max;
}
public int dc(int[] nums, int start, int end){
int max = 0;
for(int i=start;i<=end;i++){
int left = dc(nums, start, i-1);
int right = dc(nums, i+1, end);
int leftnum = start-1>=0?nums[start-1]:1;
int rightnum = end+1<nums.length?nums[end+1]:1;
int sum = left+right+leftnum*nums[i]*rightnum;
max = Math.max(sum, max);
}
return max;
}
memorization
public class Solution {
public int maxCoins(int[] nums) {
if(nums==null || nums.length==0) return 0;
int[][] dp = new int[nums.length][nums.length];
int max = dc(nums, 0, nums.length-1, dp);
return max;
}
public int dc(int[] nums, int start, int end, int[][] dp){
if(start>=0 && end<nums.length && end>=0 && start<nums.length && dp[start][end]!=0) return dp[start][end];
int max = 0;
for(int i=start;i<=end;i++){
int left = dc(nums, start, i-1, dp);
int right = dc(nums, i+1, end, dp);
int leftnum = start-1>=0?nums[start-1]:1;
int rightnum = end+1<nums.length?nums[end+1]:1;
int sum = left+right+leftnum*nums[i]*rightnum;
max = Math.max(sum, max);
}
if(start>=0 && end<nums.length && end>=0 && start<nums.length ) dp[start][end] = max;
return max;
}
}
Remove Boxes
这个题和burst ballon很像,一开始采用了和burst ballon一样的方式,转移方程为dp(left, right) = max( (i-left)*(i-left)*dp(i, right), (right-j)(right-j)*dp(j+1, right). i和j为相同连续boxes的位置。这个思路忽略了一个细节,i只merge了相邻的left box,也许存在先merge不相邻的,使得后面有间隔的left box没间隔,再统一merge.
论坛里构建的是一个三维dp,dp(i, j, k), i代表起始,j代表结束,k代表有多少个和i相同的box在i的左边。这样dp方程为res = (k+1)(k+1) + dp(i+1,j,0); if(box[i]==box[m]) 时,前k个box和第i个box也可以和第m个box结合,第m个box再决定怎么用这k个box。res = max (res, dp[i+1, m-1, 0] + dp(m, j, k+1))
dp(i, i-1, k) = 0
dp(i, i, k) = (k+1)*(k+1)
remove box[i]; dp(i,j,k) = (k+1)*(k+1)+T(i+1, j, 0)
attach box[i] to box[m] who has same color, dp (i, j, k) = dp[i+1, m-1, 0] + dp(m, j, k+1)
Top-down + memorization
public class Solution {
public int removeBoxes(int[] boxes) {
if(boxes==null || boxes.length==0) return 0;
int n = boxes.length;
int[][][] dp = new int[n][n][n];
return getMax(boxes, 0, n-1, 0, dp);
}
public int getMax(int[] boxes, int start, int end, int k, int[][][] dp){
if(start>end) return 0;
if(start==end) return (k+1)*(k+1);
if(dp[start][end][k]>0) return dp[start][end][k];
int res = (k+1)*(k+1)+getMax(boxes, start+1, end, 0, dp);
for(int m=start+1; m<=end;m++){
if(boxes[start] == boxes[m])
res = Math.max(res, getMax(boxes, start+1, m-1, 0, dp)+getMax(boxes, m,end, k+1, dp));
}
dp[start][end][k] = res;
return res;
}
}
bottom-up
public int removeBoxes(int[] boxes) {
int n = boxes.length;
int[][][] dp = new int[n][n][n];
for (int j = 0; j < n; j++) {
for (int k = 0; k <= j; k++) {
dp[j][j][k] = (k + 1) * (k + 1);
}
}
for (int l = 1; l < n; l++) {
for (int j = l; j < n; j++) {
int i = j - l;
for (int k = 0; k <= i; k++) {
int res = (k + 1) * (k + 1) + dp[i + 1][j][0];
for (int m = i + 1; m <= j; m++) {
if (boxes[m] == boxes[i]) {
res = Math.max(res, dp[i + 1][m - 1][0] + dp[m][j][k + 1]);
}
}
dp[i][j][k] = res;
}
}
}
return (n == 0 ? 0 : dp[0][n - 1][0]);
}
Scramble String
弄了半天写了个错误的版本,只考虑了 cut 位置对齐的情况,可以过 157 / 281 个 test cases, 然而像 "abc" 和 "bca" 这种起始位置就不对齐的就会出错。
a | bc
bc | a
所以很显然的,O(n^3) 泡汤了~
http://www.blogjava.net/sandy/archive/2013/05/22/399605.html
下面的是基于九章答案的记忆化搜素解法,改了我好久。。。
改写过程中一直在犯的错误是,在 subcall 中 s1,s2 已经是 substring 的情况下,依然用上一层传过来的参数作为参考去切分新的 substring. 这是错误的,只需要在参数中得到的 s1, s2 上切割就好了,因为传进来的并不是最原始的 string.
每一层 search 中,参数里面的 start / end / n 代表着相对于原始 string 的位置,用于查询和记录 DP; 而这一层的 s1, s2 又是新的子问题,除了涉及传参和DP之外的地方,都以 s1, s2 为准。
s.substring(i,j) 中,最后截取的 substring 长度就是 j - i.
public class Solution {
public boolean isScramble(String s1, String s2) {
if(!isAnagram(s1, s2)) return false;
int n = s1.length();
// dp[i][j][k] : s1 starting from index i, s2 string from index j
// pick k chars, are we getting scrambled strings ?
// 0 : not searched, 1 : true, -1 : false;
int[][][] dp = new int[n][n][n + 1];
return isScrambleMemo(s1, s2, 0, 0, n, dp);
}
private boolean isScrambleMemo(String s1, String s2, int oneStart, int twoStart, int n, int[][][] dp){
if(dp[oneStart][twoStart][n] != 0) return (dp[oneStart][twoStart][n] == 1) ? true : false;
if(s1.equals(s2)){
dp[oneStart][twoStart][n] = 1;
return true;
}
if(!isAnagram(s1, s2)){
dp[oneStart][twoStart][n] = -1;
return false;
}
// i = number of characters we take
for(int i = 1; i < s1.length() ; i++){
String s1Left = s1.substring(0, i);
String s1Right = s1.substring(i, s1.length());
String leftSideS2Left = s2.substring(0, i);
String leftSideS2Right = s2.substring(i, s2.length());
String rightSideS2Left = s2.substring(0, s2.length() - i);
String rightSideS2Right = s2.substring(s2.length() - i, s2.length());
if(isScrambleMemo(s1Left, leftSideS2Left, oneStart, twoStart, i, dp) &&
isScrambleMemo(s1Right, leftSideS2Right, oneStart + i, twoStart + i, n - i, dp)) {
dp[oneStart][twoStart][n] = 1;
return true;
}
if(isScrambleMemo(s1Left, rightSideS2Right, oneStart, twoStart + n - i, i, dp) &&
isScrambleMemo(s1Right, rightSideS2Left, oneStart + i, twoStart, n - i, dp)) {
dp[oneStart][twoStart][n] = 1;
return true;
}
}
dp[oneStart][twoStart][n] = -1;
return false;
}
// Assuming only lower case letters
private boolean isAnagram(String s1, String s2){
if(s1.length() != s2.length()) return false;
int[] hash = new int[26];
for(int i = 0; i < s1.length(); i++){
int index = s1.charAt(i) - 'a';
hash[index] ++;
}
for(int i = 0; i < s2.length(); i++){
int index = s2.charAt(i) - 'a';
hash[index] --;
if(hash[index] < 0) return false;
}
return true;
}
}
自己写的, 犯了string.substring(0,i) i=1处枚举, 和上面"abc" "bca"的错误. 过了275/282个case后TLE,
public class Solution {
HashMap<String, Boolean> map = new HashMap<String, Boolean>();
public boolean isScramble(String s1, String s2) {
if(s1==null && s2==null) return true;
if(s1==null || s2==null) return false;
if(s1.length()!=s2.length()) return false;
if(s1.length()==1 && !s1.equals(s2)) return false;
if(s1.equals(s2)||new StringBuilder(s1).reverse().toString().equals(s2)) return true;
String s1reverse = new StringBuilder(s1).reverse().toString();
String s2reverse = new StringBuilder(s2).reverse().toString();
if(map.containsKey(s1+"\n"+s2) ) return map.get(s1+"\n"+s2);
if(map.containsKey(s1reverse+"\n"+s2) ) return map.get(s1reverse+"\n"+s2);
if(map.containsKey(s1+"\n"+s2reverse) ) return map.get(s1+"\n"+s2reverse);
if(map.containsKey(s1reverse+"\n"+s2reverse) ) return map.get(s1reverse+"\n"+s2reverse);
for(int i=1;i<s1.length();i++)
for(int j=1;j<s2.length();j++)
{
String s1left = s1.substring(0,i);
String s2left = s2.substring(0,j);
String s1right = s1.substring(i);
String s2right = s2.substring(j);
//System.out.println(s1left+" " +s2left+" " +s1right+" " +s2right );
if(isScramble(s1left, s2left) && isScramble(s1right, s2right)){
map.put(s1+"\n"+s2, true);
return true;
}
if(isScramble(s1left, s2right) && isScramble(s1right, s2left)){
map.put(s1+"\n"+s2, true);
return true;
}
}
map.put(s1+"\n"+s2, false);
return false;
}
}
论坛代码, 加了一个for循环判断是否是anagram, 非常好的想法.
for (int i=0; i<26; i++) if (letters[i]!=0) return false;
public class Solution {
public boolean isScramble(String s1, String s2) {
if (s1.equals(s2)) return true;
int[] letters = new int[26];
for (int i=0; i<s1.length(); i++) {
letters[s1.charAt(i)-'a']++;
letters[s2.charAt(i)-'a']--;
}
for (int i=0; i<26; i++) if (letters[i]!=0) return false;
for (int i=1; i<s1.length(); i++) {
if (isScramble(s1.substring(0,i), s2.substring(0,i))
&& isScramble(s1.substring(i), s2.substring(i))) return true;
if (isScramble(s1.substring(0,i), s2.substring(s2.length()-i))
&& isScramble(s1.substring(i), s2.substring(0,s2.length()-i))) return true;
}
return false;
}
}
Subarray Sum Equals K
水题 presum+hashmap;注意这里不能用two pointers因为数组中可以有负数。
public class Solution {
public int subarraySum(int[] nums, int k) {
if(nums==null || nums.length==0) return 0;
TreeMap<Integer, Integer> map = new TreeMap<Integer, Integer>();
int sum = 0; int count = 0; map.put(0, 1);
for(int i=0;i<nums.length;i++){
sum+=nums[i];
if(map.containsKey(sum-k))
count+=map.get(sum-k);
map.put(sum, map.getOrDefault(sum,0)+1);
}
return count;
}
}
Brick Wall
这题有意思,扫一下就知道是按层数来的DP。算穿过砖头数的时候如果是穿过砖头的起始点或末尾点,不算砖头块。dp[i][j]代表第i层j号格子穿过的砖头数。dp[i][j] = dp[i-1][j]当j是砖头开始或结束,否则dp[i][j]=dp[i-1][j]+1.
public class Solution {
public int leastBricks(List<List<Integer>> wall) {
int total = 0;
for(Integer i : wall.get(0)){
total+=i;
}
int[][] dp = new int[2][total+1];
/*
for(int i=0;i<=total; i++) dp[1][i] = 1; int leak = 0;
for(Integer i : wall.get(0)){
dp[1][leak] = 0;
leak+=i;
}
*/
for(int i=0;i<wall.size();i++){
int prev = 0;
for(int j=0;j<wall.get(i).size();j++){
dp[i%2][prev] = dp[(i+1)%2][prev];
for(int k=prev+1;k<prev+wall.get(i).get(j);k++)
dp[i%2][k] = dp[(i+1)%2][k]+1;
prev+=wall.get(i).get(j);
}
System.out.println(Arrays.toString(dp[i%2]));
}
int min = wall.size();
for(int i=1;i<total;i++)
min = Math.min(min, dp[(wall.size()-1)%2][i]);
return min;
}
}
这个情况挂在了[[10000], [10000], [10000]]上面。稀疏但是数字大。虽然用了滚动数组但是依然MLE。这个时候HashMap起到了作用,实际上hashmap只要统计每个“缝隙”处就可以了。最后用总的层数减去最大缝隙值。这个时候看着是dp,实际又不是dp。注意layer的最后一个sum不用算。
public class Solution {
public int leastBricks(List<List<Integer>> wall) {
HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
for(List<Integer> layer : wall){
int sum =0;
for(int i=0;i<layer.size()-1;i++){
sum+=layer.get(i);
map.put(sum, map.getOrDefault(sum, 0)+1);
}
}
int min = wall.size();
for(Integer i : map.values()){
min = Math.min(min, wall.size()-i);
}
return min;
}
}