上一篇博客讲了石子合并的基本做法,n^3复杂度的dp,今天无意间看到这个优化方法,觉得有必要学习一下。
平行四边形优化是一种可以将三维DP复杂度降到n^2方的方法,但是并不是所有的dp都适用,需要满足一定条件,如下:
当决策代价函数w[i][j]满足w[ i ][ j ]+w[ i’ ][ j’ ]<=w[ i; ][ j ]+w[ i ][ j’ ](i<=i’<=j<=j’)时,称w满足四边形不等式.
当函数w[ i ][ j ]满足w[ i’ ][ j ]<=w[ i ][ j’ ] i<=i’<=j<=j’)时,称w关于区间包含关系单调.
如果满足以上两点,可利用四边形不等式推出最优决策s的单调函数性,从而减少每个状态的状态数,将算法的时间复杂度降低一维。
具体的实施是通过记录子区间的最优决策来减少当前的决策量.令:
s[ i ][ j ]=max{k | ma[ i ] [ j ] = m[ i ][ k-1 ] + m[ k ] [ j ] + w[ i ][ j ] }
即s[ i ] [ j ]就记录了合并第i到第j堆石子时的最优合并,记录是为了限制后面的循环范围,如上所说。
证明如下(转载 http://www.cnblogs.com/jiu0821/p/4493497.html):
设m[i,j]表示动态规划的状态量。
m[i,j]有类似如下的状态转移方程:
m[i,j]=opt{m[i,k]+m[k,j]}(i≤k≤j)
如果对于任意的a≤b≤c≤d,有m[a,c]+m[b,d]≤m[a,d]+m[b,c],那么m[i,j]满足四边形不等式。
以上是适用这种优化方法的必要条件
对于一道具体的题目,我们首先要证明它满足这个条件,一般来说用数学归纳法证明,根据题目的不同而不同。
通常的动态规划的复杂度是O(n3),我们可以优化到O(n2)
设s[i,j]为m[i,j]的决策量,即m[i,j]=m[i,s[i,j]]+m[s[i,j]+j]
我们可以证明,s[i,j-1]≤s[i,j]≤s[i+1,j] (证明过程见下)
那么改变状态转移方程为:
m[i,j]=opt{m[i,k]+m[k,j]} (s[i,j-1]≤k≤s[i+1,j])
复杂度分析:不难看出,复杂度决定于s的值,以求m[i,i+L]为例,
(s[2,L+1]-s[1,L])+(s[3,L+2]-s[2,L+1])…+(s[n-L+1,n]-s[n-L,n-1])=s[n-L+1,n]-s[1,L]≤n
所以总复杂度是O(n2)
对s[i,j-1]≤s[i,j]≤s[i+1,j]的证明:
设mk[i,j]=m[i,k]+m[k,j],s[i,j]=d
对于任意k<d,有mk[i,j]≥md[i,j](这里以m[i,j]=min{m[i,k]+m[k,j]}为例,max的类似),接下来只要证明mk[i+1,j]≥md[i+1,j],那么只有当s[i+1,j]≥s[i,j]时才有可能有ms[i+1,j][i+1,j]≤md[i+1,j]
(mk[i+1,j]-md[i+1,j]) - (mk[i,j]-md[i,j])
=(mk[i+1,j]+md[i,j]) - (md[i+1,j]+mk[i,j])
=(m[i+1,k]+m[k,j]+m[i,d]+m[d,j]) - (m[i+1,d]+m[d,j]+m[i,k]+m[k,j])
=(m[i+1,k]+m[i,d]) - (m[i+1,d]+m[i,k])
∵m满足四边形不等式,∴对于i<i+1≤k<d有m[i+1,k]+m[i,d]≥m[i+1,d]+m[i,k]
∴(mk[i+1,j]-md[i+1,j])≥(mk[i,j]-md[i,j])≥0
∴s[i,j]≤s[i+1,j],同理可证s[i,j-1]≤s[i,j]
证毕
解决这类dp平行四边形优化问题的大概步骤是:
1.证明w满足四边形不等式,这里w是m的附属量,如m[i,j]=opt{m[i,k]+m[k,j]+w[i,j]},此时大多要先证明w满足条件才能进一步证明m满足条件
2.证明m满足四边形不等式
3.证明s[i,j-1]≤s[i,j]≤s[i+1,j]
更新后的代码如下:
#include#include #include #include using namespace std; int n,x; int sum[205]; int dp[205][205]; int s[205][205]; int main() { while(~scanf("%d",&n)) { sum[0]=0; memset(dp ,0,sizeof dp); for(int i=1;i<=n;i++) { scanf("%d",&x); sum[i]=sum[i-1]+x; dp[i][i]=0; s[i][i]=i; } for(int len=2;len<=n;len++) for(int i=1;i<=n;i++) { int j=i+len-1; if(j>n) continue; for(int k=s[i][j-1];k<=s[i+1][j];k++) { if(dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1]