初学java,线段树 70 分 mle 3个点

P3372 【模板】线段树 1

Reaepita @ 2023-08-31 14:47:04

为什么会 mle ,求大佬解答

import java.util.Scanner;
class SegmentTree {
    long[] sum;
    long[] tag;
    SegmentTree(int size)
    {
        sum = new long[size * 4];
        tag = new long[size * 4];
    } 
    void pushUp(int rt)
    {
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    }
    void pushDown(int rt, int l, int r)
    {
        int mid = l + r >> 1;
        if (tag[rt] != 0)
        {
            tag[rt << 1] += tag[rt];
            tag[rt << 1 | 1] += tag[rt];
            sum[rt << 1] += tag[rt] * (mid-l+1);
            sum[rt << 1 | 1] += tag[rt] * (r-mid);
            tag[rt] = 0;
        }
    }
    void build(int rt, int l, int r, int[] a)
    {
        if (l == r)
        {
            sum[rt] = a[l];
            return;
        }
        int mid = l + r >> 1;
        build(rt << 1, l, mid, a);
        build(rt << 1 | 1, mid+1, r, a);
        pushUp(rt);
    }
    void update(int rt, int l, int r, int L, int R, int val)
    {
        if (L <= l && r <= R)
        {
            tag[rt] += val;
            sum[rt] += val * (r-l+1);
            return;
        }
        pushDown(rt, l, r);
        int mid = l + r >> 1;
        if (L <= mid) update(rt << 1, l, mid, L, R, val);
        if (R > mid) update(rt << 1 | 1, mid+1, r, L, R, val);
        pushUp(rt);
    }
    long query(int rt, int l, int r, int L, int R)
    {
        if (L <= l && r <= R)
        {
            return sum[rt];
        }
        pushDown(rt, l, r);
        int mid = l + r >> 1;
        long ans = 0;
        if (L <= mid) ans += query(rt << 1, l, mid, L, R);
        if (R > mid) ans += query(rt << 1 | 1, mid+1, r, L, R);
        return ans;
    }
}
public class Main {
    public static void main(String[] args) {
        int n, m;
        Scanner cin = new Scanner(System.in);
        n = cin.nextInt();
        m = cin.nextInt();
        int[] a = new int[n+1];

        SegmentTree st = new SegmentTree(n);
        for (int i = 1; i <= n; i++)a[i] = cin.nextInt();
        st.build(1, 1, n, a);
        while (m-- != 0)
        {
            int op = cin.nextInt();
            if (op == 1)
            {
                int l = cin.nextInt();
                int r = cin.nextInt();
                int val = cin.nextInt();
                st.update(1, 1, n, l, r, val);
            }
            else
            {
                int l = cin.nextInt();
                int r = cin.nextInt();
                System.out.println(st.query(1, 1, n, l, r));
            }
        }
        cin.close();

    }
}

|