1.16 小测

Sunny_r

2020-01-16 22:28:05

Personal

今天下午老师找了一道题让做一个半小时,当做小练,好像也许或许大概可能和以前一道题有点像,一眼看上去max, min, r-l+1,哇塞,都见过耶。两眼看上去,em,不太会,不行,开始思考。过了一会,算了,20pts挺好的,暴力开溜。。

题面:

\displaystyle\sum_{i=1}^{n}\displaystyle\sum_{j = i}^{n}(j - i +1) * mx[i][j] * mn[i][j],输出对1e9取模的结果,n <= 5e5

考场思路:先把每个点按照权值从小到大排序,保证min,然后在线段树上扫一扫,找到合法区间的左右边界,然而发现很难维护,凉凉~

正解:分治就好了啦,每次将左端点从mid扫到l,并维护左半区间的MnMx,在右半区间定义jk,其中j表示右半区间最后一个mn>=Mn的下标,k表示右半区间最后一个mx<=Mx的下标,考虑到jk肯定是单调不减的(因为左半部分以及右半部分的minmax肯定都是单调的),所以复杂度不会退化。

所以我们考虑如何统计贡献呢?这里就要分情况啦,假设j<=k

首先设mn[i]表示从mid + 1向右到该点的最小值,mx[i]表示从mid +1向右到该点的最大值

1.首先对于mid+1j的区间,贡献为

\displaystyle \sum_{i=mid+1}^{j}Mn * Mx * (i - l +1) = Mn * Mx * ((mid + j) * (j - mid + 1) / 2 - (j - mid + 1) * (l - 1))

2.对于j+1k的区间,贡献为

Mx * \displaystyle \sum_{i=j+1}^{k}mn[i] * (i - l + 1)=Mx * \displaystyle \sum_{i=j+1}^{k}(mn[i] * i - mn[i] * (l - 1))

考虑维护mn * i的前缀和记为smni,维护mn的前缀和记为smn

所以上述贡献即为

Mx * (smni[k] - smni[j] - (smn[k] - smn[j]) * (l - 1))

3.对于k+1到区间右端点(记为R),贡献为

\displaystyle\sum_{i = k + 1} ^ {R}mn[i] * mx[i] * (i - l + 1)=\displaystyle\sum_{i = k + 1} ^ {R}mn[i] * mx[i] * i - mn[i] * mx[i] * (l - 1)

考虑维护mn * mx * i的前缀和记为nxi,维护mx * mn的前缀和记为nx

所以上述贡献即为

nxi[R] - nxi[k] - (nx[R] - nx[k]) * (l - 1)

讨论到这里就愉快的结束啦

那么对于j > k的情况肿么办捏?那不是一样吗...

再维护一个mx * i的前缀和smxi以及mx的前缀和smx就好了啦。

#include <iostream>
#include <cstdio>
#define ll long long
#define int long long
using namespace std;
const int N = 5e5 + 5, inf = 0x3f3f3f3f, mod = 1e9;
int n;
ll mn[N], mx[N], ans, a[N], smn[N], smx[N], smni[N], smxi[N], nx[N], nxi[N];
inline int read()
{
    int x = 0, f = 1; char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') {x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
    return x * f;
}
void solve(int L, int R)
{
    if(L == R) return ans = (ans + a[L] * a[L] % mod) % mod, void();//*1而不是*L 
    int mid = (L + R) >> 1;
    solve(L, mid); solve(mid + 1, R);
    mn[mid] = mx[mid] = smn[mid] = smx[mid] = nx[mid] = nxi[mid] = smni[mid] = smxi[mid] = 0;//
    mn[mid + 1] = mx[mid + 1] = smn[mid + 1] = smx[mid + 1] = a[mid + 1];
    smni[mid + 1] = smxi[mid + 1] = a[mid + 1] * (mid + 1) % mod;
    nx[mid + 1] = a[mid + 1] * a[mid + 1] % mod;
    nxi[mid + 1] = a[mid + 1] * a[mid + 1] % mod * (mid + 1) % mod;
    for(int w = mid + 2; w <= R; w ++)
    {
        mn[w] = min(a[w], mn[w - 1]); mx[w] = max(mx[w - 1], a[w]);
        smn[w] = (smn[w - 1] + mn[w]) % mod; smx[w] = (smx[w - 1] + mx[w]) % mod;
        smni[w] = (smni[w - 1] + mn[w] * w % mod) % mod; smxi[w] = (smxi[w - 1] + mx[w] * w % mod) % mod;
        nx[w] = (nx[w - 1] + mn[w] * mx[w] % mod) % mod; nxi[w] = (nxi[w - 1] + mn[w] * mx[w] % mod * w % mod) % mod;
    }
    //smx, smn......因为取模了所以要转正 
    ll Mn = inf, Mx = 0, j = mid + 1, k = mid + 1, res;
    for(int l = mid; l >= L; l --)
    {
        if(j < mid + 1) j = mid + 1;
        if(k < mid + 1) k = mid + 1;
        Mn = min(Mn, a[l]); Mx = max(Mx, a[l]);
        while(j <= R && mn[j] >= Mn) j ++; j --;
        while(k <= R && mx[k] <= Mx) k ++; k --;
        if(j <= k)
        {
            res = Mn * Mx % mod * (((mid + 1 + j) * (j - mid) / 2) % mod - (j - mid) * (l - 1) % mod + mod) % mod;
            ans = ((ans + res) % mod + mod) % mod;
            res = Mx * (((smni[k] - smni[j] + mod) % mod - (smn[k] - smn[j] + mod) % mod * (l - 1) % mod + mod) % mod) % mod;
            ans = ((ans + res) % mod + mod) % mod;
            res = ((nxi[R] - nxi[k] + mod) % mod - (nx[R] - nx[k] + mod) % mod * (l - 1) % mod + mod) % mod;
            ans = ((ans + res) % mod + mod) % mod;
        }
        else
        {
            res = Mn * Mx % mod * (((mid + 1 + k) * (k - mid) / 2) % mod - (k - mid) * (l - 1) % mod + mod) % mod;
            ans = ((ans + res) % mod + mod) % mod;
            res = Mn * (((smxi[j] - smxi[k] + mod) % mod - (smx[j] - smx[k] + mod) % mod * (l - 1) % mod + mod) % mod) % mod;
            ans = ((ans + res) % mod + mod) % mod;
            res = ((nxi[R] - nxi[j] + mod) % mod - (nx[R] - nx[j] + mod) % mod * (l - 1) % mod + mod) % mod;
            ans = ((ans + res) % mod + mod) % mod;
        }
    }
}
signed main()
{
    freopen("C.in", "r", stdin);
    freopen("C.out", "w", stdout);
    n = read();
    for(int i = 1; i <= n; i ++) a[i] = read();
    solve(1, n); printf("%lld\n", ans);
    fclose(stdin);
    fclose(stdout);
    return 0;
}
/*
4
2 4 1 4
*/