线段树10pts求调

P3372 【模板】线段树 1

lihugang @ 2024-10-10 00:13:57

只 AC 了第一个点,后面全 WA 了

#include <stdio.h>
#include <stdlib.h>
#include <utility>
#include <vector>

typedef struct _node {
    std::pair<unsigned, unsigned> interval;
    int value;
    int lazyAdd;
} node;

class SegmentTree {
    private:
    node * nodes;
    unsigned int length;

    public:
    SegmentTree(int * sourceData, unsigned int lengthOfSourceData) {
        length = lengthOfSourceData * 4;
        nodes = (node *)malloc(sizeof(node) * length);
        build(sourceData, 1, lengthOfSourceData, 1);
    }

    private:
    void build(int * sourceData, unsigned int left, unsigned int right, unsigned int nodeId) {
        nodes[nodeId].interval = std::make_pair(left, right);
        nodes[nodeId].lazyAdd = 0;

        if (left == right) {
            nodes[nodeId].value = sourceData[left];
            return;
        }

        unsigned int middle = (left + right) / 2;
        build(sourceData, left, middle, nodeId * 2);
        build(sourceData, middle + 1, right, nodeId * 2 + 1);

        nodes[nodeId].value = nodes[nodeId * 2].value + nodes[nodeId * 2 + 1].value;
    }

    int getSum(std::pair<unsigned, unsigned> &queryInterval, unsigned int nodeId) {

        pushdown(nodeId);

        if (isSubset(nodes[nodeId].interval, queryInterval)) {
            return nodes[nodeId].value;
        }

        unsigned int middle = (nodes[nodeId].interval.first + nodes[nodeId].interval.second) / 2;

        int ans = 0;
        if (queryInterval.first <= middle) ans += getSum(queryInterval, nodeId * 2);
        if (queryInterval.second > middle) ans += getSum(queryInterval, nodeId * 2 + 1);

        return ans;
    }

    void add(std::pair <unsigned, unsigned> &operationInterval, int changeNumber, unsigned int nodeId) {
        if (isSubset(nodes[nodeId].interval, operationInterval)) {
            nodes[nodeId].lazyAdd += changeNumber;
            pushdown(nodeId);
        } else {
            pushdown(nodeId);
            unsigned int middle = (nodes[nodeId].interval.first + nodes[nodeId].interval.second) / 2;
            if (operationInterval.first <= middle) add(operationInterval, changeNumber, nodeId * 2);
            if (operationInterval.second > middle) add(operationInterval, changeNumber, nodeId * 2 + 1);
            nodes[nodeId].value = nodes[nodeId * 2].value + nodes[nodeId * 2 + 1].value;
        }
    }

    void pushdown(unsigned int nodeId) {
        if (nodes[nodeId].lazyAdd) {
            nodes[nodeId].value += nodes[nodeId].lazyAdd * getLengthOfInterval(nodes[nodeId].interval);
            if (getLengthOfInterval(nodes[nodeId].interval) > 1) {
                nodes[nodeId * 2].lazyAdd += nodes[nodeId].lazyAdd;
                nodes[nodeId * 2 + 1].lazyAdd += nodes[nodeId].lazyAdd;
            }
            nodes[nodeId].lazyAdd = 0;
        }
    }

    inline bool isSubset(std::pair<unsigned, unsigned> &a, std::pair<unsigned, unsigned> &b) {
        return a.first >= b.first && a.second <= b.second;
    }

    inline unsigned int getLengthOfInterval(std::pair <unsigned, unsigned> &interval) {
        return interval.second - interval.first + 1;
    }

    public:
    int getSum(std::pair<unsigned, unsigned> queryInterval) {
        return getSum(queryInterval, 1);
    }

    void add(std::pair<unsigned, unsigned> operationInterval, int changeNumber) {
        return add(operationInterval, changeNumber, 1);
    }
};

int main(void) {

    #ifdef DEBUG
    freopen("3372.in", "r", stdin);
    #endif

    int countOfNumbers, countOfOperations;
    scanf("%d %d", &countOfNumbers, &countOfOperations);

    int sourceData[100008];

    for (int i = 1; i <= countOfNumbers; i++) {
        scanf("%d", &sourceData[i]);
    }

    SegmentTree segmentTree(sourceData, countOfNumbers);

    for (int i = 0; i < countOfOperations; i++) {
        int operation;
        scanf("%d", &operation);

        unsigned int left, right;

        if (operation == 1) {
            int changeNumber;
            scanf("%u %u %d", &left, &right, &changeNumber);
            segmentTree.add(std::make_pair(left, right), changeNumber);
        } else {
            scanf("%u %u", &left, &right);
            printf("%d\n", segmentTree.getSum(std::make_pair(left, right)));
        }
    }
    return 0;
}

by Winalways @ 2024-10-10 00:14:59

还不睡啊哥


by lihugang @ 2024-10-10 00:22:35

@luosabi321 不 AC 线段树睡不着,能不能帮忙看看错在哪了,谢谢


by lihugang @ 2024-10-10 00:23:13

我感觉应该错在懒惰标记上了


by lihugang @ 2024-10-10 23:45:31

问题已解决

将`add`函数改为这样就可以了: ```c++ void add(std::pair <unsigned, unsigned> &operationInterval, long long changeNumber, unsigned int nodeId) { if (isSubset(nodes[nodeId].interval, operationInterval)) { nodes[nodeId].lazyAdd += changeNumber; pushdown(nodeId); } else { pushdown(nodeId); pushdown(nodeId * 2); pushdown(nodeId * 2 + 1); unsigned int middle = (nodes[nodeId].interval.first + nodes[nodeId].interval.second) / 2; if (operationInterval.first <= middle) add(operationInterval, changeNumber, nodeId * 2); if (operationInterval.second > middle) add(operationInterval, changeNumber, nodeId * 2 + 1); nodes[nodeId].value = nodes[nodeId * 2].value + nodes[nodeId * 2 + 1].value; } } ``` 多加了对子节点的`pushdown` 另外,还需要把 `int` 改为 `long long` 才能过后$3$个点

|