@xunuo
2017-01-16T14:01:03.000000Z
字数 2355
阅读 1015
Time Limit: 9000/3000 MS (Java/Others) Memory Limit: 131072/65536 K (Java/Others)
链接:
HDU 3507 Print Article
斜率DP A-Print Article
斜率DP
Zero has an old printer that doesn't work well sometimes. As it is antique, he still like to use it to print articles. But it is too old to work for a long time and it will certainly wear and tear, so Zero use a cost to evaluate this degree.
One day Zero want to print an article which has N words, and each word i has a cost Ci to be printed. Also, Zero know that print k words in one line will cost
M is a const number.
Now Zero want to know the minimum cost in order to arrange the article perfectly.
There are many test cases. For each test case, There are two numbers N and M in the first line (0 ≤ n ≤ 500000, 0 ≤ M ≤ 1000). Then, there are N numbers in the next 2 to N + 1 lines. Input are terminated by EOF.
A single number, meaning the mininum cost to print the article.
5 5
5
9
5
7
5
230
题意:
第一行包括两个数n和m,m就是题上公式中的常数
接下来输入一个序列a[n];
输出序列a[n],每连续输出的费用是连续输出的数字和的平方加上常数M
让我们求这个费用的最小值。
解题思路:
啊啊啊啊啊!这还只是一个最简单的斜率DP。。。还只是入门题,,我觉得好难啊!!!
来我们来推一下哈!!
提前说明:这里举例 i、j、k 三个数,三个数的大小顺序是k<j<i !!!
首先,由它告诉我们的的计算公式就可以知道dp[i]=min(dp[j]+(sum[i]-sum[j])^2+m);
然后,为什么说是斜率呢?
假如我们设j优于k的话,就有:dp[j]+(sum[i]-sum[j])^2+m)<dp[k]+(sum[i]-sum[k])^2+m);然后去括号化简就有:
(dp[j]+sum[j]^2)-(dp[k]+sum[k]^2)/(2*(sum[j]-sump[k]))<=sum[i];如果设yj=dp[j]+sum[j]^2,yk=dp[k]+sum[k]^2,xj=2*sum[j],xk=2*sump[k];那就正好是(yj-yk)/(xj-xk)<=sum[i];------正好就是斜率
不要问why,,,这就像高中做数学题那样,,,谁知道你写着写着就出来一个特殊的东西呢?用数学老师的话就是:显然......
我们将斜率用g(j,k)表示,就有g(j,k)<=sum[i];
再有:如果g(i,j)<=g(j,k),则有j不可能为最优的,将j删去!这个我还是不明白要怎么证,
看图比较 清 楚 明了。。。。这个时候就非常想念赵伯伯
需要维护一个斜率递增(即下凸)图形,j点不符合,故将j点删了
啊啊啊啊啊!!!其实我并不懂为什么!!!好难啊!!
完整代码:
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
int dp[500010];
int sum[50010];
int a[500010];
int q[500010];
int num1(int j,int k)
{
int yj=dp[j]+sum[j]*sum[j];
int yk=dp[k]+sum[k]*sum[k];
return yj-yk;
}
int num2(int j,int k)
{
int xj=2*sum[j];
int xk=2*sum[k];
return xj-xk;
}
int main()
{
int n,m;
int head,tail;
while(scanf("%d%d",&n,&m)!=EOF)
{
memset(sum,0,sizeof(sum));
memset(a,0,sizeof(a));
memset(dp,0,sizeof(dp));
memset(q,0,sizeof(q));
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
sum[i]=sum[i-1]+a[i];
}
head=0,tail=1;
for(int i=1;i<=n;i++)
{
while(head+1<tail&&num1(q[head+1],q[head])<=sum[i]*(num2(q[head+1],q[head])))
head++;
dp[i]=dp[q[head]]+(sum[i]-sum[q[head]])*(sum[i]-sum[q[head]])+m;
while(head+1<tail&&num1(i,q[tail-1])*(num2(q[tail-1],q[tail-2]))<=num1(q[tail-1],q[tail-2])*num2(i,q[tail-1]))
tail--;
q[tail]=i;
tail++;
}
printf("%d\n",dp[n]);
}
return 0;
}