MnZn代码求调

P2839 [国家集训队] middle

mike05 @ 2024-11-23 21:23:18

样例已过，subtask 1 已过，但 subtask 0 全 WA，求调。

#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
using ll = long long;
using ld = long double;
using pii = pair<int, int>;
const ll N = 2e4 + 10, M = N * 50;
ll cnt, root[N], n, T;
struct Edge
{
    ll a, id;
}mp[N];
struct Node
{
    ll l, r, lson, rson, val[3];                           //val0->sum   val1->lmax   val2->rmax
};
struct SEGTree
{
    Node tr[M];
    #define ls(u) tr[u].lson
    #define rs(u) tr[u].rson
    void pushup(ll u)
    {
        tr[u].val[0] = tr[ls(u)].val[0] + tr[rs(u)].val[0];
        tr[u].val[1] = max(tr[ls(u)].val[1], tr[ls(u)].val[0] + tr[rs(u)].val[1]);
        tr[u].val[2] = max(tr[rs(u)].val[2], tr[rs(u)].val[0] + tr[ls(u)].val[2]);
    }
    void build(ll &u, ll l, ll r)
    {
        u = ++ cnt;
        tr[u] = {l, r};
        if (l == r) return tr[u].val[1] = tr[u].val[2] = tr[u].val[0] = 1, void();
        ll mid = (l + r) >> 1;
        build(ls(u), l, mid); build(rs(u), mid + 1, r);
        pushup(u);
    }
    void insert(ll &u, ll v, ll x, ll k)
    {
        u = ++ cnt; tr[u] = tr[v];
        if (tr[u].l == tr[u].r) return tr[u].val[1] = tr[u].val[2] = tr[u].val[0] = k, void();
        ll mid = (tr[u].l + tr[u].r) >> 1;
        if (x <= mid) insert(ls(u), ls(v), x, k);
        else insert(rs(u), rs(v), x, k);
        pushup(u);
    }
    ll query(ll u, ll l, ll r, ll op)
    {
        if (tr[u].l >= l && tr[u].r <= r) return tr[u].val[op];
        ll mid = (tr[u].l + tr[u].r) >> 1, res = 0;
        if (op == 0)
        {
            if (l <= mid) res += query(ls(u), l, r, op);
            if (r > mid) res += query(rs(u), l, r, op);
        }
        else if (op == 1)
        {
            if (r <= mid) res = query(ls(u), l, r, op);
            else if (l > mid) res = max(res, query(rs(u), l, r, op));
            else res = max(query(ls(u), l, r, op), query(ls(u), l, mid, 0) + query(rs(u), l, r, op));
        }
        else
        {
            if (r <= mid) res = query(ls(u), l, r, op);
            else if (l > mid) res = max(res, query(rs(u), l, r, op));
            else res = max(query(rs(u), l, r, op), query(rs(u), mid + 1, r, 0) + query(ls(u), l, r, op));
        }
        return res;
    }
}seg;
bool operator < (Edge s, Edge t) { return s.a < t.a; }
// Using tree version k (positions of rank < k hold -1, the rest +1), test whether some
// interval [l, r] with a <= l <= b and c <= r <= d has a non-negative total weight.
// The interval is split into a suffix of [a, b], the gap [b + 1, c - 1] (possibly empty),
// and a prefix of [c, d].
bool check(ll k, ll a, ll b, ll c, ll d)
{
    const ll ver = root[k];
    ll best = seg.query(ver, a, b, 2) + seg.query(ver, c, d, 1);
    if (b + 1 <= c - 1) best += seg.query(ver, b + 1, c - 1, 0);
    return best >= 0;
}
int main()
{
    cin >> n;
    for (ll i = 1; i <= n; i ++ ) cin >> mp[i].a, mp[i].id = i;
    sort(mp + 1, mp + n + 1);
    seg.build(root[1], 1, n);
    for (ll i = 2; i <= n + 1; i ++ ) seg.insert(root[i], root[i - 1], mp[i - 1].id, -1);
    cin >> T;
    ll q[4], lst = 0;
    while (T -- )
    {
        for (ll i = 0; i < 4; i ++ ) cin >> q[i], q[i] = (q[i] + lst) % n;
        sort(q, q + 4);
        ll l = 1, r = n + 1, ans;
        while (l <= r)
        {
            ll mid = (l + r) >> 1;
            if (check(mid, q[0] + 1, q[1] + 1, q[2] + 1, q[3] + 1)) l = mid + 1, ans = mid;
            else r = mid - 1;
        }
        cout << (lst = mp[ans].a) << "\n";
    }
    cerr << "Time : " << clock() << " ms\n";
    return 0;
}

by mike05 @ 2024-11-23 21:45:02

此帖结。

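改成这样就可以了。原因：原代码 `query` 中 `res` 初值为 0，在区间全落在右半边时写成了 `res = max(res, query(rs(u), ...))`，会把“非空”前/后缀最大值错误地截成 0。修改后叶子的前/后缀最大值取 `max(0, k)`（允许为空），并让 `check` 强制包含 `[b, c]`，两者语义一致，这个截断也就不再出错了。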

#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
using ll = long long;
using ld = long double;
using pii = pair<int, int>;
const ll N = 2e4 + 10, M = N * 50;
ll cnt, root[N], n, T;
struct Edge
{
    ll a, id;
}mp[N];
struct Node
{
    ll l, r, lson, rson, val[3];                           //val0->sum   val1->lmax   val2->rmax
};
struct SEGTree
{
    Node tr[M];
    #define ls(u) tr[u].lson
    #define rs(u) tr[u].rson
    void pushup(ll u)
    {
        tr[u].val[0] = tr[ls(u)].val[0] + tr[rs(u)].val[0];
        tr[u].val[1] = max(tr[ls(u)].val[1], tr[ls(u)].val[0] + tr[rs(u)].val[1]);
        tr[u].val[2] = max(tr[rs(u)].val[2], tr[rs(u)].val[0] + tr[ls(u)].val[2]);
    }
    void build(ll &u, ll l, ll r)
    {
        u = ++ cnt;
        tr[u] = {l, r};
        if (l == r) return tr[u].val[1] = tr[u].val[2] = tr[u].val[0] = 1, void();
        ll mid = (l + r) >> 1;
        build(ls(u), l, mid); build(rs(u), mid + 1, r);
        pushup(u);
    }
    void insert(ll &u, ll v, ll x, ll k)
    {
        u = ++ cnt; tr[u] = tr[v];
        if (tr[u].l == tr[u].r) return tr[u].val[1] = tr[u].val[2] = max(0ll, k), tr[u].val[0] = k, void();
        ll mid = (tr[u].l + tr[u].r) >> 1;
        if (x <= mid) insert(ls(u), ls(v), x, k);
        else insert(rs(u), rs(v), x, k);
        pushup(u);
    }
    ll query(ll u, ll l, ll r, ll op)
    {
        if (tr[u].l >= l && tr[u].r <= r) return tr[u].val[op];
        ll mid = (tr[u].l + tr[u].r) >> 1, res = 0;
        if (op == 0)
        {
            if (l <= mid) res += query(ls(u), l, r, op);
            if (r > mid) res += query(rs(u), l, r, op);
        }
        else if (op == 1)
        {
            if (r <= mid) res = query(ls(u), l, r, op);
            else if (l > mid) res = max(res, query(rs(u), l, r, op));
            else res = max(query(ls(u), l, r, op), query(ls(u), l, mid, 0) + query(rs(u), l, r, op));
        }
        else
        {
            if (r <= mid) res = query(ls(u), l, r, op);
            else if (l > mid) res = max(res, query(rs(u), l, r, op));
            else res = max(query(rs(u), l, r, op), query(rs(u), mid + 1, r, 0) + query(ls(u), l, r, op));
        }
        return res;
    }
}seg;
bool operator < (Edge s, Edge t) { return s.a < t.a; }
// Using tree version k (positions of rank < k hold -1, the rest +1), test whether some
// interval [l, r] with a <= l <= b and c <= r <= d has a non-negative total weight.
// [b, c] is always inside the interval; the tree's prefix/suffix maxima allow the empty
// choice, so extending left into [a, b - 1] or right into [c + 1, d] never hurts.
bool check(ll k, ll a, ll b, ll c, ll d)
{
    const ll ver = root[k];
    const ll core = seg.query(ver, b, c, 0);
    const ll left = seg.query(ver, a, b - 1, 2);
    const ll right = seg.query(ver, c + 1, d, 1);
    return core + left + right >= 0;
}
int main()
{
    cin >> n;
    for (ll i = 1; i <= n; i ++ ) cin >> mp[i].a, mp[i].id = i;
    sort(mp + 1, mp + n + 1);
    seg.build(root[1], 1, n);
    for (ll i = 2; i <= n + 1; i ++ ) seg.insert(root[i], root[i - 1], mp[i - 1].id, -1);
    // for (ll i = 1; i <= cnt; i ++ ) 
        // cout << seg.tr[i].l << " " << seg.tr[i].r << " " << seg.tr[i].val[0] << " " << seg.tr[i].val[1] << " " << seg.tr[i].val[2] << "\n";
    cin >> T;
    ll q[4], lst = 0;
    while (T -- )
    {
        for (ll i = 0; i < 4; i ++ ) cin >> q[i], q[i] = (q[i] + lst) % n;
        sort(q, q + 4);
        ll l = 1, r = n + 1, ans;
        while (l <= r)
        {
            ll mid = (l + r) >> 1;
            if (check(mid, q[0] + 1, q[1] + 1, q[2] + 1, q[3] + 1)) l = mid + 1, ans = mid;
            else r = mid - 1;
        }
        cout << (lst = mp[ans].a) << "\n";
    }
    cerr << "Time : " << clock() << " ms\n";
    return 0;
}

|