[关闭]
@397915842 2014-10-15T12:02:02.000000Z 字数 3148 阅读 1871

Acdream 1174 Sum

题解 数学


题目链接:

http://acdream.info/problem?pid=1174

题目大意:

给出N, a[1]... a[N],还有M, b[1]... b[M], (1 <= N,M <= 50000), (1 <= a[i],b[i] <= 10000)

  1. #include<cmath>
  2. #include<iostream>
  3. using namespace std;
  4. int N, M;
  5. long long ans;
  6. int main()
  7. {
  8. while(cin >> N >> M)
  9. {
  10. ans = 0;
  11. for(int i = 1; i <= N; ++ i)
  12. for(int j = 1; j <= M; ++ j)
  13. ans += (i - j) * abs(a[i] - b[j]);
  14. cout << ans << endl;
  15. }
  16. return 0;
  17. }

具体题解:

直接暴力妥妥超时,刚开始第一眼看到这题的时候,觉得可能是可以转化成一个数学问题,然后就可以很轻易地算出来,后来一直没有找到合适的方法,只好开始暴力了。
首先先假设b[j] <= a[i]对于所有的 ij 都成立,那么这题瞬间就很简单了。

c\r 1 2 ... M
1 (1 - 1) * (a[1] - b[1]) (1 - 2) * (a[1] - b[2]) ... (1 - M) * (a[1] - b[M])
2 (2 - 1) * (a[2] - b[1]) (2 - 2) * (a[2] - b[2]) ... (2 - M) * (a[2] - b[M])
... ... ... ... ...
N (N - 1) * (a[N] - b[1]) (N - 2) * (a[N] - b[2]) ... (N - M) * (a[N] - b[M])

考虑每行,可以算出 a[i] 的个数有 M * (1 - M) / 2 + (i - 1) * M 个;
考虑每列,可以算出 b[j] 的个数有 N * (1 - N) / 2 + (j - 1) * N 个。

但是这样得出来的结果明显是错的,好在这错误是可以修正的,接下来我们考虑怎么修正。
刚开始的时候我是这样想的,直接对 b 数组进行排序,那么对于每个 a[i],我们可以直接用 upper_bounds 算出 b 数组中有多少个数比 a[i] 小,这时候考虑每行,那么就可以知道有多少个 a[i] - b[j] 是计算错误的,然后先把 a[i] 修正回来;接下来对 a 数组进行排序,那么对于每个 b[j],同样用 upper_bounds 算出 a 数组中有多少个数比 b[j] 小,这时候考虑每列,那么就可以知道有多少个 a[i] - b[j] 是计算错误的,然后先把 b[j] 修正回来。

然后,就果断 TLE 了,后来想了想,这边超时主要还是因为 upper_bounds,这个方法是没问题的, nlogn 的算法果断是可以的,这时候就需要对这里计算个数的方法进行优化。

想了想,完全可以不需要 upper_bounds,直接用 O(N + M) 的预处理就能立马算出想要的个数。

接下来大概讲解一下修正的过程,首先对 ab 进行排序,我们先对 a 进行修正,假设

b[1]<=b[2]<=...<=b[x]<=a[i]<=b[x+1]...<=b[M]
那么明显只有 (a[i] - b[1]) ... (a[i] - b[x]) 计算正确,其它都是计算错误的,那么我们应该修正的就是 (a[i] - b[x + 1]) ... (a[i] - b[M]) 这些项了。如果直接对于每个 a[i] 都要走遍历一遍 b 数组,那么妥妥超时的,这里注意到,
a[1]<=a[2]<=...<=a[N]
那么,我们至少可以保证 a[i + 1] >= b[x] 的,这时候 a[i + 1] 只需要从 b[x + 1] 开始考虑即可,这样就是 O(N + M) 的算法了,这样就不会超时了。注意到对于每一行,b 数组前面减掉的下标的和均为 M * (M + 1) / 2,那么用一个 sum_ind_b 来记录 b 的下标和,每次发现有一个有一个 b[j] <= a[i] 就减去 b[j].index,那么这里 sum_ind_b = M * (M + 1) / 2 - b[1].index - ... - b[x].index。接下来算 a[i] 前的 i 的总和,假设有 kb[j] 小于等于 a[i],那么和明显是 k * a[i].index,然后只需要算 2 * (k * a[i].index - sum_ind_b) 即可算出 a[i] 多加上了多少个,直接减去即可。
修正 b 与修正 a 同理,考虑每列,假设
a[1]<=a[2]<=...<=a[x]<=b[j]<=a[x+1]<=...<=a[N]
那么,只有 (a[1] - b[j]) ... (a[x] - b[j]) 是计算错误的,后面都是计算正确的,那么只需要修正 (a[1] - b[j]) ... (a[x] - b[j]) 即可。这里用 sum_ind_a 记录 a 数组下标的和, 那么这里的 sum_ind_a = a[1].index + ... + a[x].index,然后 j 的总和是 x * b[j],那么这里 b[j] 多加上的个数为 2 * (sum_ind_a - x * b[j].index),直接加上即可。
虽然说思想上没问题了,但是却 WA 了好多发,检查了好多遍程序,后来不经意间发现,原来是相等的情况没有考虑好。刚开始修正 a 的时候,无视了 a[i] == b[j] 的情况,后面修正 b 的时候,代码没写好,考虑了 a[i] == b[j] 的情况,然后两边没有抵消掉, WA 得我差点哭了。
后来想了想,刚开始的时候完全没必要直接假设所有的 a[i] >= b[j],直接对 ab 进行排序,然后找到 a[i] - b[j] 符号改变的位置,然后将每一行和每一列分成两段来计算,这样程序就可以更快了。
后来突然注意到其实没必要排序的,注意到 1 <= a[i],b[i] <= 10000,这样直接可以用计数排序,那么代码就可以更快了。

代码:

  1. #include<iostream>
  2. #include<algorithm>
  3. using namespace std;
  4. typedef long long LL;
  5. const int MAXN = 5e5 + 5;
  6. class Node
  7. {
  8. public:
  9. LL ind, val;
  10. };
  11. LL N, M, ans;
  12. Node a[MAXN], b[MAXN];
  13. LL sum_ind_a, sum_ind_b;
  14. bool cmp(const Node &x, const Node &y)
  15. {
  16. return x.val < y.val;
  17. }
  18. int main()
  19. {
  20. cin.sync_with_stdio(false);
  21. while(cin >> N >> M)
  22. {
  23. ans = 0;
  24. for(LL i = 1, cnt = M * (1 - M) / 2; i <= N; ++ i, cnt += M)
  25. {
  26. cin >> a[i].val;
  27. a[i].ind = i;
  28. ans += a[i].val * cnt;
  29. }
  30. for(LL j = 1, cnt = N * (1 - N) / 2; j <= M; ++ j, cnt += N)
  31. {
  32. cin >> b[j].val;
  33. b[j].ind = j;
  34. ans += b[j].val * cnt;
  35. }
  36. sort(a + 1, a + N + 1, cmp);
  37. sort(b + 1, b + M + 1, cmp);
  38. sum_ind_b = M * (M + 1) / 2;
  39. for(LL i = 1, lim = 1; i <= N; ++ i)
  40. {
  41. while(lim <= M && a[i].val >= b[lim].val)
  42. sum_ind_b -= b[lim ++].ind;
  43. ans -= 2 * a[i].val * (a[i].ind * (M - lim + 1) - sum_ind_b);
  44. }
  45. sum_ind_a = 0;
  46. for(LL j = 1, lim = 1; j <= M; ++ j)
  47. {
  48. while(lim <= N && a[lim].val < b[j].val)
  49. sum_ind_a += a[lim ++].ind;
  50. ans += 2 * b[j].val * (sum_ind_a - (lim - 1) * b[j].ind);
  51. }
  52. cout << ans << endl;
  53. }
  54. return 0;
  55. }
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注