@397915842
2014-10-15T12:02:02.000000Z
字数 3148
阅读 1896
题解
数学
http://acdream.info/problem?pid=1174
给出N, a[1]... a[N],还有M, b[1]... b[M], (1 <= N,M <= 50000), (1 <= a[i],b[i] <= 10000)
#include<cmath>
#include<iostream>
using namespace std;
int N, M;
long long ans;
int main()
{
while(cin >> N >> M)
{
ans = 0;
for(int i = 1; i <= N; ++ i)
for(int j = 1; j <= M; ++ j)
ans += (i - j) * abs(a[i] - b[j]);
cout << ans << endl;
}
return 0;
}
直接暴力妥妥超时,刚开始第一眼看到这题的时候,觉得可能是可以转化成一个数学问题,然后就可以很轻易地算出来,后来一直没有找到合适的方法,只好开始暴力了。
首先先假设b[j] <= a[i]对于所有的 i 和 j 都成立,那么这题瞬间就很简单了。
c\r | 1 | 2 | ... | M |
---|---|---|---|---|
1 | (1 - 1) * (a[1] - b[1]) | (1 - 2) * (a[1] - b[2]) | ... | (1 - M) * (a[1] - b[M]) |
2 | (2 - 1) * (a[2] - b[1]) | (2 - 2) * (a[2] - b[2]) | ... | (2 - M) * (a[2] - b[M]) |
... | ... | ... | ... | ... |
N | (N - 1) * (a[N] - b[1]) | (N - 2) * (a[N] - b[2]) | ... | (N - M) * (a[N] - b[M]) |
考虑每行,可以算出 a[i] 的个数有 M * (1 - M) / 2 + (i - 1) * M 个;
考虑每列,可以算出 b[j] 的个数有 N * (1 - N) / 2 + (j - 1) * N 个。
但是这样得出来的结果明显是错的,好在这错误是可以修正的,接下来我们考虑怎么修正。
刚开始的时候我是这样想的,直接对 b 数组进行排序,那么对于每个 a[i],我们可以直接用 upper_bounds 算出 b 数组中有多少个数比 a[i] 小,这时候考虑每行,那么就可以知道有多少个 a[i] - b[j] 是计算错误的,然后先把 a[i] 修正回来;接下来对 a 数组进行排序,那么对于每个 b[j],同样用 upper_bounds 算出 a 数组中有多少个数比 b[j] 小,这时候考虑每列,那么就可以知道有多少个 a[i] - b[j] 是计算错误的,然后先把 b[j] 修正回来。
然后,就果断 TLE 了,后来想了想,这边超时主要还是因为 upper_bounds,这个方法是没问题的, nlogn 的算法果断是可以的,这时候就需要对这里计算个数的方法进行优化。
想了想,完全可以不需要 upper_bounds,直接用 O(N + M) 的预处理就能立马算出想要的个数。
接下来大概讲解一下修正的过程,首先对 a 和 b 进行排序,我们先对 a 进行修正,假设
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long LL;
const int MAXN = 5e5 + 5;
class Node
{
public:
LL ind, val;
};
LL N, M, ans;
Node a[MAXN], b[MAXN];
LL sum_ind_a, sum_ind_b;
bool cmp(const Node &x, const Node &y)
{
return x.val < y.val;
}
int main()
{
cin.sync_with_stdio(false);
while(cin >> N >> M)
{
ans = 0;
for(LL i = 1, cnt = M * (1 - M) / 2; i <= N; ++ i, cnt += M)
{
cin >> a[i].val;
a[i].ind = i;
ans += a[i].val * cnt;
}
for(LL j = 1, cnt = N * (1 - N) / 2; j <= M; ++ j, cnt += N)
{
cin >> b[j].val;
b[j].ind = j;
ans += b[j].val * cnt;
}
sort(a + 1, a + N + 1, cmp);
sort(b + 1, b + M + 1, cmp);
sum_ind_b = M * (M + 1) / 2;
for(LL i = 1, lim = 1; i <= N; ++ i)
{
while(lim <= M && a[i].val >= b[lim].val)
sum_ind_b -= b[lim ++].ind;
ans -= 2 * a[i].val * (a[i].ind * (M - lim + 1) - sum_ind_b);
}
sum_ind_a = 0;
for(LL j = 1, lim = 1; j <= M; ++ j)
{
while(lim <= N && a[lim].val < b[j].val)
sum_ind_a += a[lim ++].ind;
ans += 2 * b[j].val * (sum_ind_a - (lim - 1) * b[j].ind);
}
cout << ans << endl;
}
return 0;
}