最后三个点MLE,求教

P3372 【模板】线段树 1

睡眼惺忪 @ 2023-08-13 16:42:41

import java.util.Scanner;
import java.util.Set;
import java.util.stream.Collectors;

public class Main {

    static class TreeNode{
        public int l;
        public int r;
        public Long val;
        public Long flag;
        public TreeNode(int l, int r, Long val, Long flag) {
            this.l = l;
            this.r = r;
            this.val = val;
            this.flag = flag;
        }
    }

    static TreeNode[] tree;

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();
        int m = scanner.nextInt();
        Long[] num = new Long[n + 1];
        for (int i = 1; i <= n; i++) {
            num[i] = scanner.nextLong();
        }
        tree = new TreeNode[4 * n + 2];
        for (int i = 0; i <= 4 * n + 1; i++) {
            tree[i] = new TreeNode(0,0,0L,0L);
        }
        build(1,n,num,1);
        scanner.nextLine();
        for (int i = 0; i < m; i++) {
            String[] s = scanner.nextLine().split(" ");
            if (s[0].equals("1")) {
                add(Integer.valueOf(s[1]), Integer.valueOf(s[2]),Long.valueOf(s[3]),1);
            }
            if (s[0].equals("2")) {
                Long sum = sum(Integer.valueOf(s[1]), Integer.valueOf(s[2]),1);
                System.out.println(sum);
            }
        }
    }

    private static void add(int l, int r, Long k, int cur) {
        if (l <= tree[cur].l && r >= tree[cur].r) {
            tree[cur].val += k * (tree[cur].r - tree[cur].l + 1);
            tree[cur].flag += k;
            return;
        }
        spread(cur);
        int mid = tree[cur].l + tree[cur].r >> 1;
        if (mid >= l) {
            add(l,r,k,cur * 2);
        }
        if (mid < r) {
            add(l, r, k,cur * 2 + 1);
        }
        tree[cur].val = tree[cur * 2].val + tree[cur * 2 + 1].val;
    }

    private static Long sum(int l, int r, int cur) {
        if (l <= tree[cur].l && r >= tree[cur].r) {
            return tree[cur].val;
        }
        spread(cur);
        int mid = tree[cur].l + tree[cur].r >> 1;
        Long left = 0L;
        Long right = 0L;
        if (mid >= l) {
            left = sum(l, r, cur * 2);
        }
        if (mid < r) {
            right = sum(l, r, cur * 2 + 1);
        }
        return left + right;
    }

    private static void build(int l, int r, Long[] num, int cur) {
        tree[cur].l = l;
        tree[cur].r = r;
        if (l == r) {
            tree[cur].val = num[l];
            return;
        }
        int mid = l + r >> 1;
        build(l,mid,num,cur * 2);
        build(mid + 1,r,num,cur * 2 + 1);
        tree[cur].val = tree[cur * 2].val + tree[cur * 2 + 1].val;
    }

    private static void spread(int cur) {
        if (tree[cur].flag != 0L) {
            tree[cur * 2].val += tree[cur].flag * (tree[cur * 2].r - tree[cur * 2].l + 1);
            tree[cur * 2 + 1].val += tree[cur].flag * (tree[cur * 2 + 1].r - tree[cur * 2 + 1].l + 1);
            tree[cur * 2].flag += tree[cur].flag;
            tree[cur * 2 + 1].flag += tree[cur].flag;
            tree[cur].flag = 0L;
        }
    }
}

|